In [1]:
import os

# Set before torch is imported so CUDA device selection is fixed up front
# (PCI bus ordering, only device index 2 exposed to this process).
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import time
import traceback

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from tqdm.notebook import tqdm

# Project-local modules (kept in their original relative order so that names
# imported after the star import still take precedence over it).
from dataset import FFHQDataset
from model import *
from loss import ProWGANDLoss
from config import Config

In [2]:
class Trainer:
    """Progressive-resolution GAN trainer for a Generator/Discriminator pair.

    The trainer walks through growth steps ``cfg.current_step .. cfg.steps - 1``.
    At step ``s`` images are resized to ``2 ** (s + 2)`` pixels per side and
    each step is trained for ``cfg.epoches`` epochs.  ``alpha`` is forwarded to
    both networks and ramps linearly up to 1 over the first half of the epochs
    of a step (presumably the fade-in weight of the newly added resolution --
    confirm against ``model.py``).

    Attributes read from ``cfg``: device, z_dim, max_channels, max_size,
    preG_path, preD_path, lr, fixed_z_path, noise_mean, noise_std, data_dir,
    num_dataset, current_step, steps, current_epoch, epoches, log_path,
    imgs_dir, models_dir.
    """

    def __init__(self, cfg, dataset):
        """Build the models, optimizers, loss and the fixed evaluation noise.

        Args:
            cfg: configuration object (see class docstring for the fields used).
            dataset: dataset *class* (not instance); it is instantiated in
                ``get_loader`` with ``data_dir``, ``transforms`` and ``num``.
        """
        self.cfg = cfg
        self.dataset = dataset
        print('Init G and D models')
        self.G = Generator(z_dim=cfg.z_dim, in_channels=cfg.max_channels, max_size=cfg.max_size).to(cfg.device)
        self.D = Discriminator(max_size=cfg.max_size, out_channels=cfg.max_channels).to(cfg.device)

        # Checkpoints are whole pickled modules (that is what ``train`` saves),
        # so only load files you trust.  ``map_location`` remaps tensors onto
        # the configured device; without it a checkpoint written on a GPU index
        # that is not visible in this process cannot be deserialized.
        # NOTE(review): recent torch releases default ``weights_only=True`` in
        # torch.load, which rejects full-module pickles -- confirm the torch
        # version in use.
        if cfg.preG_path is not None:
            print('Load pretrained G model')
            pretrained_G = torch.load(cfg.preG_path, map_location=cfg.device)
            self.G = load_dict(self.G, pretrained_G.state_dict(), device=cfg.device)
        if cfg.preD_path is not None:
            print('Load pretrained D model')
            pretrained_D = torch.load(cfg.preD_path, map_location=cfg.device)
            self.D = load_dict(self.D, pretrained_D.state_dict(), device=cfg.device)

        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=cfg.lr, betas=(0.0, 0.99))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=cfg.lr, betas=(0.0, 0.99))
        self.lossfn_D = ProWGANDLoss().to(cfg.device)

        # Fixed latent batch (64 samples) reused by ``test`` so that sample
        # grids from different epochs are comparable.
        if cfg.fixed_z_path is None:
            self.fixed_z = torch.normal(cfg.noise_mean, cfg.noise_std, size=(64, cfg.z_dim, 1, 1)).to(cfg.device)
        else:
            self.fixed_z = torch.load(cfg.fixed_z_path, map_location=cfg.device).to(cfg.device)

    @staticmethod
    def _batch_size(step):
        """Return the batch size for a growth step: ``3*1024 // 4**step``, at least 1.

        Steps 0..5 (image sizes 4..128) give 3 * [1024, 256, 64, 16, 4, 1].
        The floor of 1 keeps the DataLoader valid for deeper steps, where the
        integer division would otherwise yield 0.
        """
        return max(1, 3 * 1024 // 4 ** step)

    def get_loader(self, step):
        """Create the DataLoader for one growth step.

        Args:
            step: growth step index; images are resized to ``2 ** (step + 2)``.

        Returns:
            A ``DataLoader`` over ``self.dataset`` with a step-dependent batch size.
        """
        img_size = 2 ** (step + 2)
        transforms_ = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((img_size, img_size))
        ])
        img_dataset = self.dataset(data_dir=self.cfg.data_dir, transforms=transforms_, num=self.cfg.num_dataset)
        # NOTE(review): shuffle is off, so every epoch sees the same batches in
        # the same order -- confirm this is intentional.
        return DataLoader(
            img_dataset,
            batch_size=self._batch_size(step),
            shuffle=False,
            num_workers=10
        )

    def train(self):
        """Run the progressive training loop and save both models on exit.

        Exceptions raised during training are reported (message and traceback)
        but not re-raised.  ``KeyboardInterrupt`` is not an ``Exception`` and
        therefore propagates, after the ``finally`` block has saved the models.
        Models are only saved if at least one epoch was started; the file names
        carry the step/epoch that was running last.
        """
        cfg = self.cfg
        # Step/epoch of the most recently started epoch.  Tracked explicitly so
        # the ``finally`` block never touches unbound loop variables when the
        # failure happens before the first epoch begins.
        last_step = last_epoch = None
        try:
            print('Start training')
            for step in range(cfg.current_step, cfg.steps):
                loader = self.get_loader(step)
                resuming = cfg.current_step == step
                start_epoch = cfg.current_epoch if resuming else 0
                # When resuming, recompute alpha from the epoch index; clamp it
                # because past the halfway point the raw value exceeds 1.
                alpha = min(start_epoch * 2 / cfg.epoches, 1) if resuming else 1e-8
                for epoch in range(start_epoch, cfg.epoches):
                    last_step, last_epoch = step, epoch
                    begin_time = time.time()
                    for i, img in enumerate(tqdm(loader, desc=f'epoch {epoch+1}/{cfg.epoches}:')):
                        z = torch.normal(cfg.noise_mean, cfg.noise_std, size=(img.size(0), cfg.z_dim, 1, 1))
                        img, z = img.to(cfg.device), z.to(cfg.device)

                        # train D
                        self.D.zero_grad()
                        loss_D = self.lossfn_D(self.G, self.D, z, img, step, alpha)
                        loss_D.backward()
                        self.optimizer_D.step()

                        # train G: maximise the critic score of generated images
                        self.G.zero_grad()
                        loss_G = -self.D(self.G(z, step=step, alpha=alpha), step=step, alpha=alpha).mean()
                        loss_G.backward()
                        self.optimizer_G.step()

                        # smooth increase alpha
                        # it reaches 1 after half of epoches
                        alpha += 2 * img.size(0) / (cfg.epoches * cfg.num_dataset)
                        alpha = min(alpha, 1)
                    end_time = time.time()
                    self.log(begin_time, end_time, step, epoch, alpha, loss_G, loss_D, log_path=cfg.log_path)
                    self.test(step, epoch, alpha, imgs_dir=cfg.imgs_dir, imshow=(epoch+1)%10==0)
        except Exception as e:
            print('Exception: ', e)
            # The message alone rarely locates the failure; show where it came from.
            traceback.print_exc()
        finally:
            # save models
            if cfg.models_dir is not None:
                if last_step is None:
                    print('No epoch was started; skip saving models')
                else:
                    print('Saving models')
                    os.makedirs(cfg.models_dir, exist_ok=True)
                    torch.save(self.G, os.path.join(cfg.models_dir, f'G_step{last_step}_epoch{last_epoch}.pth'))
                    torch.save(self.D, os.path.join(cfg.models_dir, f'D_step{last_step}_epoch{last_epoch}.pth'))

    def test(self, step, epoch, alpha, imgs_dir=None, imshow=True):
        """Render a sample grid from ``self.fixed_z`` and optionally show/save it.

        Args:
            step: growth step passed to the generator.
            epoch: epoch index, used only in the output file name.
            alpha: alpha value passed to the generator.
            imgs_dir: if not None, the grid is written to
                ``<imgs_dir>/<step>-<epoch>.png`` (the directory is created
                if missing).
            imshow: if True, display the grid with matplotlib.
        """
        with torch.no_grad():
            samples = self.G(self.fixed_z, step=step, alpha=alpha)
            out_img = make_grid(samples, padding=2, normalize=True).cpu().permute(1,2,0).numpy()
        if imshow:
            # Create the figure only when it is displayed and close it
            # afterwards, so repeated calls do not accumulate open figures.
            fig = plt.figure(figsize=(10,10))
            plt.axis("off")
            plt.imshow(out_img)
            plt.show()
            plt.close(fig)
        if imgs_dir is not None:
            os.makedirs(imgs_dir, exist_ok=True)
            lowest, highest = np.min(out_img), np.max(out_img)
            value_range = highest - lowest
            if value_range > 0:
                out_img = (out_img - lowest) * 255 / value_range
            else:
                # Constant image: avoid 0/0 and write an all-black grid.
                out_img = np.zeros_like(out_img)
            im = Image.fromarray(out_img.astype(np.uint8))
            im.save(os.path.join(imgs_dir, f'{step}-{epoch}.png'))

    def log(self, begin_time, end_time, step, epoch, alpha, loss_G, loss_D, log_path=None):
        """Print a one-line epoch summary and optionally append it to a file.

        Args:
            begin_time, end_time: ``time.time()`` stamps bracketing the epoch.
            step, epoch: zero-based indices; printed one-based.
            alpha: current alpha value, printed in scientific notation.
            loss_G, loss_D: objects exposing ``.item()`` (the last batch losses).
            log_path: if not None, the line is appended to this file.
        """
        out_str = '[total time: {:.5f}s] '.format(end_time-begin_time) + f'[Step: {step+1}/{self.cfg.steps}] [Epoch: {epoch+1}/{self.cfg.epoches}] [alpha: {format(alpha, ".2e")}] [G loss: {loss_G.item()}] [D loss: {loss_D.item()}]'
        print(out_str)
        if log_path is not None:
            with open(log_path, 'a') as f:
                f.write(out_str+'\n')

In [3]:
# Run configuration. The settings are spelled out as a mapping and handed to
# Config as keyword arguments, in the same order as before.
cfg = Config(**{
    # --- runtime ---
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
    # --- optimisation and model size ---
    'lr': 1e-3,
    'max_size': 128,
    # --- schedule: epochs per growth step and where to (re)start ---
    'epoches': 30,
    'current_step': 4,   # Trainer.get_loader resizes to 2 ** (step + 2) = 64 px
    'current_epoch': 0,
    # --- latent noise ---
    'noise_mean': 0,
    'noise_std': 1,
    'z_dim': 512,
    'max_channels': 512,
    # --- data ---
    'num_dataset': 30000,
    # --- checkpoints to start from ---
    'preG_path': './data/trained_models/G32.pth',
    'preD_path': './data/trained_models/D32.pth',
    'fixed_z_path': './data/trained_models/fixed_z.pth',
    'data_dir': '/data1/cgl/dataset/face/seeprettyface_yellow_face_128/thumbnails128x128',
    # --- outputs ---
    'log_path': './data/logs/log64.txt',
    'imgs_dir': './data/imgs',
    'models_dir': './data/trained_models',
})

In [4]:
# Build the trainer; the dataset is passed as a class and instantiated per step.
trainer = Trainer(cfg=cfg, dataset=FFHQDataset)

Init G and D models
Load pretrained G model
Load pretrained D model


In [5]:
# Start training from cfg.current_step / cfg.current_epoch. Trainer.train saves
# G and D in its finally block, so interrupting the cell still writes checkpoints.
trainer.train()

Start training


epoch 1/30::   0%|          | 0/2500 [00:00<?, ?it/s]

Saving models


KeyboardInterrupt: 