In [1]:
import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.layers import ModulateConvBlock
#from src.models.blocks import ConvBlock
from src.utilities.util import grid_to_list

class ConvBlock(nn.Sequential):
    def __init__(self, in_channel, out_ch, ker_size, stride, padding):
        super(ConvBlock,self).__init__()
        self.add_module('conv',nn.Conv2d(in_channel, out_ch, 
                                         kernel_size=ker_size,
                                         stride=stride,
                                         padding=padding)),
        self.add_module('norm',nn.BatchNorm2d(out_ch)),
        #self.add_module('norm',nn.InstanceNorm2d(out_ch)),
        self.add_module('LeakyRelu',nn.LeakyReLU(0.2, inplace=True))
        #self.add_module('Nonlinearity',nn.Tanh())
        #self.add_module('GELU',nn.GELU())
        #self.add_module('Nonlinearity', nn.Hardswish(inplace=True))

#     def weights_init(m):
#         classname = m.__class__.__name__
#         if classname.find('Conv2d') != -1:
#             m.weight.data.normal_(0.0, 0.02)
#         elif classname.find('Norm') != -1:
#             m.weight.data.normal_(1.0, 0.02)
#             m.bias.data.fill_(0) 

class GenBlock(nn.Module):
    def __init__(self, latent_size, in_ch, out_ch, kernel, stride, padding, pool=None):
        super(GenBlock, self).__init__()        
        self.mod_conv =  ModulateConvBlock(latent_size, in_ch, out_ch, kernel)
        #self.mod_conv = ConvBlock(in_ch, out_ch, kernel, stride=stride, padding=padding)
        self.conv = ConvBlock(out_ch, out_ch, kernel, stride=stride, padding=padding)
        self.to_points = nn.Conv2d(out_ch, 3, kernel, stride=stride, padding=padding)
        self.pool = nn.AvgPool2d(kernel_size=pool, stride=pool) if pool else None
    
    def upscale(self, x, scale_factor):
        return F.interpolate(x, scale_factor=scale_factor, mode='bilinear', 
                             align_corners=True)# if scale_factor else x     
    
    def forward(self, x, z, prev_vrt):        
        y = self.mod_conv(x, z)
        y = self.conv(y)
        vrt = self.to_points(self.pool(y)) if self.pool else self.to_points(y)                    
        scale_factor = prev_vrt.size(-1) // vrt.size(-1)        
        if scale_factor > 1:
            vrt = self.upscale(vrt, scale_factor)
        vrt = vrt + prev_vrt
        return x, vrt
        
        

    
class Generator(nn.Module):
    def __init__(self, opt):
        super(Generator, self).__init__()                      
        latent_size, in_ch, out_ch, ker_size, stride, padding= (opt.latent_size,
            opt.G_in_ch, opt.G_out_ch, opt.ker_size, opt.stride, opt.padd_size)
        self.head =  ConvBlock(in_ch, out_ch, ker_size, stride=stride, padding=padding)
        self.pools = [4, 2, None]
        
        self.b1 = GenBlock(latent_size, out_ch, out_ch, ker_size, stride=stride, 
                      padding=padding, pool=self.pools[0])
        self.b2 = GenBlock(latent_size, out_ch, out_ch, ker_size, stride=stride, 
                      padding=padding, pool=self.pools[1])
        self.b3 = GenBlock(latent_size, out_ch, out_ch, ker_size, stride=stride, 
                           padding=padding, pool=self.pools[2])
    
    def forward(self, points, latents):                        
        x = self.head(points)
        
        x, vrt = self.b1(x, latents[0], points)
        x, vrt = self.b2(x, latents[1], vrt)
        x, vrt = self.b3(x, latents[2], vrt)
        
        #vrt = grid_to_list(vrt)
        return vrt
 

from src.config import get_parser

config = get_parser().parse_args(args=[])
G = Generator(config)
print(G)

G(torch.rand(5, 3, 64, 64), [torch.rand(5, config.latent_size) for _ in range(3)]).shape

Generator(
  (head): ConvBlock(
    (conv): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (b1): GenBlock(
    (mod_conv): ModulateConvBlock(
      (style): DenseBlock(
        (fc): Linear(in_features=256, out_features=256, bias=False)
        (activate): Identity()
      )
      (activate): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (conv): ConvBlock(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (to_points): Conv2d(256, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool): AvgPool2d(kernel_size=4, stride=4, padding=0)
  )
  (b2): GenBlock(
    (mod_conv): ModulateConvBlock(
      (sty

torch.Size([5, 3, 64, 64])

In [2]:
device = torch.device("cuda")
device

G = G.to(device)

In [3]:
from src.data.masked_datamodule import MaskedDataModule


config.data_patch_size = 16
config.batch_size = 8
config.num_workers = 6
dm = MaskedDataModule(config)
dm.setup()
dm

torch.Size([16, 3, 512, 512])


<src.data.masked_datamodule.MaskedDataModule at 0x7f0b9c0b0760>

In [4]:
optimizer = torch.optim.Adam(G.parameters())
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)

In [5]:
 for idx, batch in enumerate(iter(dm.train_dataloader())):
        optimizer.zero_grad()
        
        points_coarse = batch['points_coarse'].to(device)#batch['points']             
        points_fine = batch['points'].to(device)
        bs = points_fine.size(0)        
        print(points_fine.shape)
        # train generator
        with torch.autograd.set_detect_anomaly(True):            
            points_noise = points_coarse + torch.randn_like(points_coarse) * config.G_noise_amp
            z_size = config.latent_size
            zero_styles = [torch.zeros(bs, z_size, device=points_coarse.device) for _ in range(3)]
            vertices = G(points_noise, zero_styles)
            print(idx, vertices.shape)
            #break
            #print(vertices.shape)
            vrt_loss = F.mse_loss(vertices, points_noise.reshape_as(vertices))
            
            points_fine = points_fine.reshape_as(vertices)
            vrt_loss = F.l1_loss(vertices, points_fine)
            
            vrt_loss.backward()
            optimizer.step()
            print(idx, vrt_loss)

torch.Size([8, 3, 16, 16])
0 torch.Size([8, 3, 16, 16])
0 tensor(0.4121, device='cuda:0', grad_fn=<L1LossBackward>)
torch.Size([8, 3, 16, 16])
1 torch.Size([8, 3, 16, 16])
1 tensor(4.7421, device='cuda:0', grad_fn=<L1LossBackward>)
torch.Size([8, 3, 16, 16])
2 torch.Size([8, 3, 16, 16])
2 tensor(3.7910, device='cuda:0', grad_fn=<L1LossBackward>)
torch.Size([8, 3, 16, 16])
3 torch.Size([8, 3, 16, 16])
3 tensor(2.0477, device='cuda:0', grad_fn=<L1LossBackward>)
torch.Size([8, 3, 16, 16])
4 torch.Size([8, 3, 16, 16])
4 tensor(1.2277, device='cuda:0', grad_fn=<L1LossBackward>)
torch.Size([8, 3, 16, 16])
5 torch.Size([8, 3, 16, 16])
5 tensor(1.3469, device='cuda:0', grad_fn=<L1LossBackward>)
torch.Size([8, 3, 16, 16])
6 torch.Size([8, 3, 16, 16])
6 tensor(1.7351, device='cuda:0', grad_fn=<L1LossBackward>)
torch.Size([8, 3, 16, 16])
7 torch.Size([8, 3, 16, 16])
7 tensor(1.6608, device='cuda:0', grad_fn=<L1LossBackward>)
torch.Size([8, 3, 16, 16])
8 torch.Size([8, 3, 16, 16])
8 tensor(1.1330,

KeyboardInterrupt: 