In [1]:
from argparse import Namespace

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from src.config import get_parser
from src.models.discriminator import Discriminator
from src.models.generator import Generator
from src.models.stylist import Stylist
from src.renderer import Renderer
from src.data.masked_datamodule import MaskedDataModule

In [2]:
class GAN(pl.LightningModule):
    
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.automatic_optimization = False
        self.mean = sum(hparams.image_mean) / len(hparams.image_mean)
        self.std = sum(hparams.image_std) / len(hparams.image_std)
        
        self.G = Generator(hparams)
        self.D = Discriminator(hparams)
        self.S = Stylist(hparams)
        # Renderer requires device, created in .to() step
        self.R = Renderer(hparams)
        
     
    def forward(self, shape, style):
        return self.G(shape, style)
    
    def adversarial_loss(self, y_hat, y):
        return F.mse_loss(y_hat, y)
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        style_img = batch['style_img']
        img_patch = batch['img_patch']
        points =  batch['points']
        normals = batch['normals']            
        bs = style_img.size(0)
        
        # train generator
        if optimizer_idx == 0:
            style = self.S(style_img)
            vertices = self.G(points, normals, style)            
            renders =  self.R(vertices)            
            renders = (renders - self.mean) / self.std

            # ground truth result (ie: all fake)
            # put on GPU because we created this tensor inside training_loop
            valid = torch.ones(bs, 1).type_as(style_img)            

            # adversarial loss is binary cross-entropy
            g_loss = self.adversarial_loss(self.D(renders), valid)
            tqdm_dict = {'g_loss': g_loss}
            output = OrderedDict({
                'loss': g_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

        # train discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # how well can it label as real?
            valid = torch.ones(bs, 1).type_as(style_img)         

            real_loss = self.adversarial_loss(self.D(imgs), valid)

            # how well can it label as fake?
            fake = torch.zeros(imgs.size(0), 1)
            fake = fake.type_as(imgs)
            
            style = self.S(style_img)
            vertices = self.G(points, normals, style)            
            renders =  self.R(vertices)            
            renders = (renders - self.mean) / self.std

            fake_loss = self.adversarial_loss(
                self.D(renders.detach()), fake)

            # discriminator loss is the average of these
            d_loss = (real_loss + fake_loss) / 2
            tqdm_dict = {'d_loss': d_loss}
            output = OrderedDict({
                'loss': d_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

    
    def configure_optimizers(self):
        lr_g = self.hparams.lr_g
        lr_d = self.hparams.lr_d
        b1 = self.hparams.beta1
        b2 = self.hparams.beta2      
        opt_gs = torch.optim.Adam(list(self.G.parameters()) 
                                 + list(self.S.parameters()), 
                                 lr=lr_g, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.D.parameters(), 
                                 lr=lr_d, betas=(b1, b2))
        return [opt_gs, opt_d], []
    
config = get_parser().parse_args(args=[])    
model = GAN(config)
model

GAN(
  (G): Generator(
    (head): ModConvLayer(
      (conv): EqualizedModConv2d(6, 32, 3, upsample=False, downsample=False)
    )
    (body): Sequential(
      (block1): ConvBlock(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (block2): ConvBlock(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (block3): ConvBlock(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    

In [3]:
dm = MaskedDataModule(config)
dm

<src.data.masked_datamodule.MaskedDataModule at 0x7fc28a9b18e0>

In [4]:
trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name | Type          | Params
---------------------------------------
0 | G    | Generator     | 31.3 K
1 | D    | Discriminator | 83.2 K
2 | S    | Stylist       | 9.3 M 
3 | R    | Renderer      | 0     
---------------------------------------
9.4 M     Trainable params
0         Non-trainable params
9.4 M     Total params
37.620    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

points.shape, normals.shape, tm.shape torch.Size([4, 65536, 3]) torch.Size([4, 1000000, 3]) torch.Size([4, 3])



RuntimeError: CUDA error: device-side assert triggered

In [3]:
dm = MaskedDataModule()
model = GAN(*dm.size())
trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)
trainer.fit(model, dm)

False

In [None]:
img_dir ='/home/bobi/Desktop/db/ffhq-dataset/images1024x1024'
mask_dir = '/home/bobi/Desktop/face-parsing.PyTorch/res/masks'

ds = MaskedDataset(img_dir, mask_dir)    
ds[0]

In [4]:
import torch

In [6]:
torch.ones((2, 1))

tensor([[1.],
        [1.]])

In [2]:
import pytorch3d.transforms as T3


In [6]:
T3.Translate(torch.rand(1, 4, 3))

TypeError: must be real number, not NoneType

In [3]:
device = torch.device("cuda")
points = torch.rand((4, 1000, 3), device=device)
T = T3.Translate(-points.mean(dim=-2, keepdim=False), device=points.device)#.to(points.device)

In [4]:
points = torch.rand((4, 1000, 3), device=device)
T = T3.Translate(-points.mean(dim=-2, keepdim=False)).to(points.device)

In [5]:
points = torch.rand((4, 1000, 3), device=device)
T = T3.Translate(-points.mean(dim=-2, keepdim=False)).type_as(points)

AttributeError: 'Translate' object has no attribute 'type_as'