In [27]:
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

from src.models.util import ConvBlock
from src.utilities.operators import mean_distance
from types import SimpleNamespace

class SinGenerator(nn.Module):
    def __init__(self, channels):
        super(SinGenerator,self).__init__()
        self.trunk = nn.Sequential(OrderedDict([
            ('head', ConvBlock(3, channels[0])),
            ('main', nn.Sequential(OrderedDict([
                ('b'+str(i), ConvBlock(in_ch, out_ch))
                for i, (in_ch, out_ch) in 
                enumerate(zip(channels, channels[1:]))])))
        ]))
        self.points = nn.Sequential(
            spectral_norm(nn.Conv2d(channels[-1], 1, 3, 1, 1, bias=False)),
            nn.Sigmoid(),)      

    def scale(self, t, size):
        return F.interpolate(t, size=size, mode='bilinear', align_corners=True)

    def forward(self, outline):
        trunk = self.trunk(outline)
        points = self.points(trunk)        
        return points

    
class Generator(nn.Module):
    def __init__(self, config):
        super(Generator, self).__init__()                      
        channels =  config.fast_generator_channels 
        self.points = SinGenerator(channels)
        #self.colors = SkipGenerator(channels)
    
    def get_means(self, t):        
        means = t.detach().reshape(t.size(0), -1).mean(1)
        return means.reshape(-1, 1, 1, 1)
        
    def forward(self, points, normals):
        #print(baseline.size(-1),  outline.size(-1))       
        dist = mean_distance(pts) / 2
        magnitudes = self.points(points)        
        magnitudes = magnitudes - self.get_means(magnitudes)
        magnitudes = magnitudes * dist.reshape(-1, 1, 1, 1)           
        res = points + normals * magnitudes
        return points, torch.ones_like(points)


sn = SimpleNamespace(fast_generator_channels=[ 128, 128, 256, 256])

G = Generator(sn)
G

Generator(
  (points): SinGenerator(
    (trunk): Sequential(
      (head): ConvBlock(
        (conv): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (lrelu): LeakyReLU(negative_slope=0.2)
      )
      (main): Sequential(
        (b0): ConvBlock(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (lrelu): LeakyReLU(negative_slope=0.2)
        )
        (b1): ConvBlock(
          (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (lrelu): LeakyReLU(negative_slope=0.2)
        )
        (b2): ConvBlock(
          (conv): Conv2d(256, 256, kernel_size=(3, 3),

In [28]:
pts = torch.rand(2, 3, 4, 4)
pts.shape

torch.Size([2, 3, 4, 4])

In [30]:
vrt, _ = G(pts, pts)
vrt.shape

torch.Size([2, 3, 4, 4])

In [24]:
dist = mean_distance(pts)
dist

tensor([0.6706, 0.7254])

In [25]:
means = vrt.detach().reshape(pts.size(0), -1).mean(1)
means, means.shape

(tensor([0.5016, 0.5575]), torch.Size([2]))

In [26]:
vrt - means.reshape(-1, 1, 1, 1)

tensor([[[[ 0.1090, -0.0768,  0.2078,  0.0793],
          [ 0.0754, -0.0079,  0.0776, -0.0473],
          [-0.0315,  0.0518, -0.1015,  0.0277],
          [-0.1717, -0.0138, -0.1900,  0.0120]]],


        [[[-0.0134,  0.0399,  0.0886,  0.1228],
          [ 0.0176,  0.0110, -0.0573, -0.0844],
          [-0.0429,  0.1672,  0.1582,  0.0219],
          [-0.0550, -0.1100, -0.1240, -0.1403]]]], grad_fn=<SubBackward0>)