In [4]:
import os
from argparse import Namespace

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from src.config import get_parser
from src.models.discriminator import Discriminator
from src.models.generator import Generator
from src.models.stylist import Stylist
from src.data.masked_datamodule import MaskedDataModule

In [5]:
class GAN(pl.LightningModule):
    """Adversarial training of a style-conditioned Generator against a Discriminator.

    Per batch (see ``training_step``):
      * the Stylist ``S`` encodes ``style_img`` into a style code,
      * the Generator ``G`` maps ``(points, normals, style)`` to vertices,
      * the renderer turns those vertices into images,
      * the Discriminator ``D`` scores the images against ``img_patch``.

    Two optimizers are used (see ``configure_optimizers``): index 0 updates G and S,
    index 1 updates D. The adversarial loss is the least-squares (MSE) GAN loss.
    """

    def __init__(self, hparams):
        """
        Args:
            hparams: hyperparameter namespace, presumably the argparse Namespace from
                ``get_parser()`` (confirm). It must provide ``lr``, ``b1`` and ``b2``, plus
                whatever Generator / Discriminator / Stylist / Renderer read from it.
        """
        super().__init__()
        # Assigning ``self.hparams`` directly is rejected by newer Lightning versions.
        # save_hyperparameters stores them (also in checkpoints) and exposes self.hparams.
        self.save_hyperparameters(hparams)
        # Automatic optimization stays ON: training_step relies on ``optimizer_idx`` and
        # returns the loss for Lightning to backpropagate. In manual mode Lightning would
        # not pass ``optimizer_idx``, and nothing here steps the optimizers itself.

        self.G = Generator(hparams)
        self.D = Discriminator(hparams)
        self.S = Stylist(hparams)
        # NOTE(review): ``Renderer`` is not imported in this notebook; add its import.
        # An earlier note said the renderer needs a device and should be created in the
        # train step. Confirm it follows ``.to(device)`` when built here.
        self.renderer = Renderer(hparams)

    def forward(self, shape, style):
        """Run the Generator on ``shape`` conditioned on the style code ``style``.

        NOTE(review): ``_render`` calls ``self.G(points, normals, style)`` with three
        arguments. Confirm which signature ``Generator.forward`` actually has.
        """
        return self.G(shape, style)

    def adversarial_loss(self, y_hat, y):
        """Least-squares GAN loss: mean squared error between predictions and targets."""
        return F.mse_loss(y_hat, y)

    def _render(self, batch):
        """Run Stylist -> Generator -> renderer on one batch and return the renders."""
        style = self.S(batch['style_img'])
        vertices = self.G(batch['points'], batch['normals'], style)
        # TODO: normalize the renders to the value range of ``img_patch`` before the
        # Discriminator sees them (marked but not implemented in the original).
        return self.renderer(vertices)

    def _log_renders(self, renders, batch_idx, n_samples=6):
        """Log the first ``n_samples`` renders once per epoch (on the first batch only)."""
        if batch_idx != 0 or self.logger is None:
            return
        # NOTE(review): assumes a TensorBoard logger (SummaryWriter.add_images) and renders
        # shaped (N, C, H, W). The original make_grid call assumed the same layout.
        self.logger.experiment.add_images(
            'generated_images', renders[:n_samples].detach(), self.global_step)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for optimizer ``optimizer_idx`` (0: G + S, 1: D).

        Args:
            batch: dict with keys ``'style_img'``, ``'img_patch'``, ``'points'``, ``'normals'``.
            batch_idx: index of the batch within the epoch.
            optimizer_idx: index into the optimizers returned by ``configure_optimizers``.

        Returns:
            The scalar loss for that optimizer, or None for any other index.
        """
        # Generator + Stylist step: push D to score the renders as real (target 1).
        if optimizer_idx == 0:
            renders = self._render(batch)
            self._log_renders(renders, batch_idx)
            pred_fake = self.D(renders)
            g_loss = self.adversarial_loss(pred_fake, torch.ones_like(pred_fake))
            self.log('g_loss', g_loss, prog_bar=True)
            return g_loss

        # Discriminator step: real images -> 1, renders -> 0.
        if optimizer_idx == 1:
            # NOTE(review): ``img_patch`` is assumed to be the real image the renders are
            # judged against. Confirm.
            real = batch['img_patch']
            # Only D is updated here, so the renders need no gradient history.
            with torch.no_grad():
                renders = self._render(batch)

            pred_real = self.D(real)
            real_loss = self.adversarial_loss(pred_real, torch.ones_like(pred_real))

            pred_fake = self.D(renders)
            fake_loss = self.adversarial_loss(pred_fake, torch.zeros_like(pred_fake))

            # The discriminator loss is the average of the real and fake terms.
            d_loss = (real_loss + fake_loss) / 2
            self.log('d_loss', d_loss, prog_bar=True)
            return d_loss

    def configure_optimizers(self):
        """Return Adam for G + S (index 0) and Adam for D (index 1), with no schedulers."""
        lr = self.hparams.lr
        betas = (self.hparams.b1, self.hparams.b2)
        # The Stylist is trained jointly with the Generator by the first optimizer.
        opt_gs = torch.optim.Adam(
            list(self.G.parameters()) + list(self.S.parameters()),
            lr=lr, betas=betas)
        opt_d = torch.optim.Adam(self.D.parameters(), lr=lr, betas=betas)
        return [opt_gs, opt_d], []

In [3]:
# Build hyperparameters, data and model, then train.
# NOTE(review): assumes get_parser() returns an argparse.ArgumentParser and that
# MaskedDataModule accepts the parsed hparams Namespace. Confirm against src/.
# parse_args([]) keeps argparse from parsing the Jupyter kernel's own sys.argv (e.g. `-f ...json`).
hparams = get_parser().parse_args([])
dm = MaskedDataModule(hparams)
model = GAN(hparams)  # GAN takes the hparams namespace, not data dimensions
trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)
trainer.fit(model, dm)

False

In [None]:
# Dataset locations. Override via environment variables instead of editing machine-specific paths.
img_dir = os.environ.get('FFHQ_IMG_DIR', '/home/bobi/Desktop/db/ffhq-dataset/images1024x1024')
mask_dir = os.environ.get('FFHQ_MASK_DIR', '/home/bobi/Desktop/face-parsing.PyTorch/res/masks')

# NOTE(review): MaskedDataset is not imported anywhere in this notebook; import it
# (presumably from the src.data package) before running this cell.
ds = MaskedDataset(img_dir, mask_dir)
ds[0]