# GAN

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os

In [6]:
class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.channels, self.img_rows, self.img_cols)  # PyTorch是通道优先
        self.latent_dim = 100
        self.img_size = self.img_rows * self.img_cols * self.channels
        
        # init
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        
        # gen model
        self.generator = self.build_generator().to(self.device)
        self.discriminator = self.build_discriminator().to(self.device)
        
        # loss
        self.adversarial_loss = nn.BCELoss()
        
        # opti
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        
        # create folder path
        if not os.path.exists('images'):
            os.makedirs('images')
    
    def build_generator(self):
        model = nn.Sequential(
            nn.Linear(self.latent_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm1d(256, momentum=0.8),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm1d(512, momentum=0.8),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm1d(1024, momentum=0.8),
            nn.Linear(1024, self.img_size),  # output 28*28*1=784
            nn.Tanh(),
            nn.Unflatten(1, (self.channels, self.img_rows, self.img_cols))# (1, 28, 28)
    )
        
        # show model
        print("Generator Architecture:")
        print(model)
        return model
    
    def build_discriminator(self):
        model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(int(np.prod(self.img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
        # show model
        print("\nDiscriminator Architecture:")
        print(model)
        return model
    
    def load_data(self, batch_size=128):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))  # normalize to [-1, 1]
        ])
        
        train_dataset = torchvision.datasets.MNIST(
            root='./data', 
            train=True, 
            download=True, 
            transform=transform
        )
        
        train_loader = DataLoader(
            train_dataset, 
            batch_size=batch_size, 
            shuffle=True,
            num_workers=2
        )
        
        return train_loader
    
    def train(self, epochs, batch_size=128, sample_interval=50):
        train_loader = self.load_data(batch_size)
        
        for epoch in range(epochs):
            for i, (real_imgs, _) in enumerate(train_loader):
                # init real image
                real_imgs = real_imgs.to(self.device)
                batch_size = real_imgs.size(0)
                
                # label
                valid = torch.ones(batch_size, 1, device=self.device)
                fake = torch.zeros(batch_size, 1, device=self.device)
                
                # ====================
                # train discriminator
                # ====================
                self.optimizer_D.zero_grad()
                
                # real img loss
                real_loss = self.adversarial_loss(self.discriminator(real_imgs), valid)
                
                # gen fake img
                z = torch.randn(batch_size, self.latent_dim, device=self.device)
                gen_imgs = self.generator(z)
                
                # fake img loss
                fake_loss = self.adversarial_loss(self.discriminator(gen_imgs.detach()), fake)
                
                # discriminator loss
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                self.optimizer_D.step()
                
                # accuracy of discriminator
                real_accuracy = torch.mean((self.discriminator(real_imgs) > 0.5).float()).item()
                fake_accuracy = torch.mean((self.discriminator(gen_imgs.detach()) < 0.5).float()).item()
                d_accuracy = 0.5 * (real_accuracy + fake_accuracy)
                
                # ====================
                # train gen
                # ====================
                self.optimizer_G.zero_grad()
                
                # gen new img
                z = torch.randn(batch_size, self.latent_dim, device=self.device)
                gen_imgs = self.generator(z)
                
                # gen loss：confuse discriminator
                g_loss = self.adversarial_loss(self.discriminator(gen_imgs), valid)
                g_loss.backward()
                self.optimizer_G.step()
                
            if epoch % 10 == 0:
                print(f"[Epoch {epoch}/{epochs}] "
                      f"[D loss: {d_loss.item():.4f}, acc: {d_accuracy*100:.2f}%] "
                      f"[G loss: {g_loss.item():.4f}]")
            
            # save img sample
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
    
    def sample_images(self, epoch):
        # save gen sample
        self.generator.eval()
        with torch.no_grad():
            # add noise
            z = torch.randn(25, self.latent_dim, device=self.device)
            gen_imgs = self.generator(z)
            
            # transmit range [-1, 1] to [0, 1]
            gen_imgs = 0.5 * gen_imgs + 0.5
            
            # visualization
            fig, axs = plt.subplots(5, 5, figsize=(10, 10))
            cnt = 0
            for i in range(5):
                for j in range(5):
                    # to numpy
                    img = gen_imgs[cnt].cpu().numpy() # (1, 28, 28)
                    img = np.transpose(img, (1, 2, 0)) # (28, 28, 1)
                    img = img.squeeze()  # (1, 28, 28) -> (28, 28)
                    img = np.clip(img, 0, 1)
                    
                    axs[i, j].imshow(img, cmap='gray')
                    axs[i, j].axis('off')
                    cnt += 1
            
            fig.suptitle(f"Epoch {epoch}", fontsize=16)
            fig.tight_layout()
            fig.savefig(f"images/{epoch}.png")
            plt.close(fig)
        
        self.generator.train()
    
    def save_models(self, path="./models"):
        if not os.path.exists(path):
            os.makedirs(path)
        
        torch.save(self.generator.state_dict(), f"{path}/generator.pth")
        torch.save(self.discriminator.state_dict(), f"{path}/discriminator.pth")
        print(f"Models saved to {path}")
    
    def load_models(self, path="./models"):
        if os.path.exists(f"{path}/generator.pth"):
            self.generator.load_state_dict(torch.load(f"{path}/generator.pth", map_location=self.device))
            self.discriminator.load_state_dict(torch.load(f"{path}/discriminator.pth", map_location=self.device))
            print(f"Models loaded from {path}")
            return True
        return False




In [7]:
if __name__ == '__main__':
    # init
    gan = GAN()
    
    # train param
    EPOCHS = 100
    BATCH_SIZE = 32
    SAMPLE_INTERVAL = 200
    
    print("Starting training...")
    gan.train(epochs=EPOCHS, batch_size=BATCH_SIZE, sample_interval=SAMPLE_INTERVAL)
    
    # save model
    gan.save_models()

Using device: cpu
Generator Architecture:
Sequential(
  (0): Linear(in_features=100, out_features=256, bias=True)
  (1): LeakyReLU(negative_slope=0.2, inplace=True)
  (2): BatchNorm1d(256, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
  (3): Linear(in_features=256, out_features=512, bias=True)
  (4): LeakyReLU(negative_slope=0.2, inplace=True)
  (5): BatchNorm1d(512, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
  (6): Linear(in_features=512, out_features=1024, bias=True)
  (7): LeakyReLU(negative_slope=0.2, inplace=True)
  (8): BatchNorm1d(1024, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
  (9): Linear(in_features=1024, out_features=784, bias=True)
  (10): Tanh()
  (11): Unflatten(dim=1, unflattened_size=(1, 28, 28))
)

Discriminator Architecture:
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=512, bias=True)
  (2): LeakyReLU(negative_slope=0.2, inplace=True)
  (3): Linear(in_featu

* From Keras GAN,https://github.com/eriklindernoren/Keras-GAN/tree/master/gan