In [193]:
import torch.nn as nn
import torch

import numpy as np
import math

In [201]:
def dw_conv(
    in_c: int, out_c: int, kernel_size: int, stride: int, padding: int
):
    """Depthwise-separable convolution usable in place of ``nn.Conv2d``.

    Takes the same leading positional arguments as ``nn.Conv2d``
    (``in_c, out_c, kernel_size, stride, padding``) and returns an
    ``nn.Sequential`` whose children are, in order:

      0. a per-channel conv (``groups=in_c``), in_c -> in_c, with the given
         kernel size / stride / padding;
      1. ``BatchNorm2d(in_c)``;
      2. ``LeakyReLU(0.2)``;
      3. a 1x1 pointwise conv, in_c -> out_c (stride 1, no padding).
    """
    # Layers are created in Sequential order so parameter initialisation
    # consumes the RNG in the same sequence.
    depthwise = nn.Conv2d(in_c, in_c, kernel_size, stride, padding, groups=in_c)
    norm = nn.BatchNorm2d(in_c)
    activation = nn.LeakyReLU(negative_slope=0.2)
    pointwise = nn.Conv2d(in_c, out_c, kernel_size=1)  # stride=1, padding=0 are the defaults
    return nn.Sequential(depthwise, norm, activation, pointwise)

class HideNet(nn.Module):
    """U-Net-style encoder/decoder with skip connections.

    The encoder is ``n_conv`` ``down_block``s (default 4x4 conv, stride 2),
    followed by a shape-preserving residual bottleneck. The decoder is
    ``n_conv`` ``up_block``s; each one takes the running output concatenated
    (dim 1) with the encoder feature map of matching depth, upsamples 2x and
    convolves. The last decoder block ends in a Sigmoid.

    Args:
        in_c: channels of the network input.
        out_c: channels of the network output.
        first_c: channels produced by the first encoder block (doubled per block).
        n_depthwise: selects which blocks use ``dw_conv`` instead of
            ``nn.Conv2d`` (see ``__init__``); must be in [1, 4].
        upsampling_mode: ``mode`` passed to every decoder ``nn.Upsample``.
        n_conv: number of encoder (and decoder) blocks.
        max_c: upper bound on any encoder block's channel count.
    """

    def get_channels(self, in_channels=6, out_channels=3, init_channels=64, max_channels=512, num_conv=6):
        """Compute per-block channel counts for the encoder and decoder.

        Encoder block ``i`` outputs ``min(init_channels * 2**i, max_channels)``
        and takes the previous block's output (``in_channels`` for block 0).
        Decoder block ``k`` takes ``2 *`` the ``k``-th deepest encoder width
        (running output concatenated with the matching skip). It emits the
        next-shallower encoder width; the last block emits ``out_channels``.

        Returns:
            ``(encoder_in, encoder_out, decoder_in, decoder_out)``, each a list
            of length ``num_conv``.

        Raises:
            ValueError: if ``num_conv < 1``.
        """
        if num_conv < 1:
            raise ValueError(f"num_conv must be >= 1, got {num_conv}")

        encoder_out = [min(init_channels * 2 ** i, max_channels) for i in range(num_conv)]
        encoder_in = [in_channels] + encoder_out[:-1]

        # The decoder consumes skip connections deepest-first.
        skips = encoder_out[::-1]
        decoder_in = [2 * c for c in skips]
        decoder_out = skips[1:] + [out_channels]

        return encoder_in, encoder_out, decoder_in, decoder_out

    def down_block(self, in_c: int, out_c: int, conv: nn.Module=nn.Conv2d, kernel_size: int=4, stride: int=2):
        """``conv(in_c, out_c, kernel_size, stride, padding=1)`` -> BatchNorm2d -> LeakyReLU(0.2).

        With the defaults (k=4, s=2, p=1) each spatial dim is halved (floored);
        with k=3, s=1 the spatial size is preserved.
        """
        return nn.Sequential(
            conv(in_c, out_c, kernel_size, stride, 1), nn.BatchNorm2d(out_c), nn.LeakyReLU(0.2)
        )

    def up_block(
        self, in_c: int, out_c: int, conv: nn.Module, act=nn.ReLU, mode: str = "nearest"
    ):
        """Upsample x2 (``mode``) -> ``conv(in_c, out_c, 3, 1, 1)`` -> BatchNorm2d -> ``act()``."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode=mode),
            conv(in_c, out_c, 3, 1, 1),
            nn.BatchNorm2d(out_c),
            act(),
        )

    def __init__(
        self,
        in_c: int = 6,
        out_c: int = 3,
        first_c: int = 64,
        n_depthwise: int = 2,
        upsampling_mode: str = "nearest",
        n_conv: int = 6,
        max_c: int = 512,
    ):
        super().__init__()
        assert 1 <= n_depthwise <= 4, "n_depthwise must be between 1 and 4"

        # in_c / out_c must be forwarded; otherwise the network is always
        # built for get_channels' defaults (6 in, 3 out).
        self.down_in, self.down_out, self.up_in, self.up_out = self.get_channels(
            in_channels=in_c,
            out_channels=out_c,
            init_channels=first_c,
            max_channels=max_c,
            num_conv=n_conv,
        )

        # Encoder: the first n_depthwise blocks use a plain conv, deeper ones dw_conv.
        # NOTE(review): the decoder does the opposite (its first n_depthwise
        # blocks are depthwise), so the halves are not mirror images; confirm intent.
        down_layers = []
        for i, (c_in, c_out) in enumerate(zip(self.down_in, self.down_out)):
            conv = nn.Conv2d if i < n_depthwise else dw_conv
            down_layers.append(self.down_block(c_in, c_out, conv))

        up_layers = []
        for i, (c_in, c_out) in enumerate(zip(self.up_in[:-1], self.up_out[:-1])):
            conv = dw_conv if i < n_depthwise else nn.Conv2d
            up_layers.append(self.up_block(c_in, c_out, conv, mode=upsampling_mode))
        # Output block: plain conv + Sigmoid.
        up_layers.append(
            self.up_block(
                self.up_in[-1], self.up_out[-1], nn.Conv2d, act=nn.Sigmoid, mode=upsampling_mode
            )
        )

        self.down_layers = nn.ModuleList(down_layers)
        # Shape-preserving (k=3, s=1, p=1) block used as a residual bottleneck.
        self.bottleneck = self.down_block(self.down_out[-1], self.down_out[-1], kernel_size=3, stride=1)
        self.up_layers = nn.ModuleList(up_layers)

    def forward(self, x):
        """Run the encoder, residual bottleneck and decoder.

        Assumes ``x`` is (N, in_c, H, W) with H and W divisible by
        ``2**n_conv``. Encoder blocks floor-halve and decoder blocks double, so
        other sizes would misalign the skip concatenations. Returns
        (N, out_c, H, W) from the final Sigmoid block.
        """
        # features[0] is the input; features[k] is the output of encoder block k-1.
        features = [x]
        for layer in self.down_layers:
            features.append(layer(features[-1]))

        # Out-of-place residual add: does not mutate the bottleneck's output tensor.
        y = self.bottleneck(features[-1]) + features[-1]

        # Decoder block k-1 sees the running output plus the k-th deepest feature map.
        for k, layer in enumerate(self.up_layers[:-1], start=1):
            y = layer(torch.cat([y, features[-k]], dim=1))

        return self.up_layers[-1](torch.cat([y, features[1]], dim=1))

In [212]:
torch.manual_seed(0)  # reproducible random input and weight initialisation
x = torch.randn(1, 6, 256, 128)  # NCHW; 6 channels = HideNet's default in_c; 256 and 128 divisible by 2**3
net = HideNet(first_c=16, max_c=4096, n_conv=3)

In [213]:
out = net(x)
# The decoder should restore the input resolution with out_c (3 by default) channels.
assert out.shape == (x.shape[0], 3, *x.shape[2:]), out.shape
out.shape

torch.Size([1, 3, 256, 128])

In [209]:
math.log2(128)  # number of times 128 can be halved

7.0

In [146]:
# Show the module tree of `net`.
# NOTE(review): the saved output below came from an older kernel state (In [146]:
# first layer Conv2d(6, 64, ...)). It was not produced by HideNet(first_c=16, ...)
# from cell 212, and it is truncated; re-run this cell to refresh it.
net

HideNet(
  (down_layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (2): Sequential(
      (0): Sequential(
        (0): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), groups=128)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
        (3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
   

In [60]:
def get_channels(in_c, first_c, max_c, num_conv=6, out_c=3):
    """Alternative channel schedule that also lists the bottleneck stage.

    NOTE(review): a later cell defines another ``get_channels`` with a different
    signature; whichever cell ran last wins.

    Args:
        in_c: channels of the network input.
        first_c: starting channel count (doubled each encoder stage).
        max_c: cap applied to every appended encoder entry.
        num_conv: number of encoder/decoder stages (default 6, as before).
        out_c: channels of the final decoder output (default 3, as before).

    Returns:
        ``(down_in, down_out, up_in, up_out)`` where

        * ``down_in`` and ``down_out`` have ``num_conv + 1`` entries;
        * ``up_in`` is ``[down_out[num_conv]]`` followed by
          ``2 * down_out[-i - 2]`` for each stage;
        * ``up_out`` has ``num_conv`` entries: ``down_in`` reversed, with the
          last one replaced by ``out_c``.

    Raises:
        ValueError: if ``num_conv < 1``; ``up_out[-1]`` would not exist.
    """
    if num_conv < 1:
        raise ValueError(f"num_conv must be >= 1, got {num_conv}")

    down_in = [in_c]
    down_out = [first_c]
    channels = first_c
    for _ in range(num_conv):
        down_in.append(min(channels, max_c))
        down_out.append(min(2 * channels, max_c))
        channels *= 2

    up_in = [down_out[num_conv]]
    up_out = []
    for i in range(num_conv):
        up_in.append(2 * down_out[-i - 2])  # running output + matching skip
        up_out.append(down_in[-i - 1])
    up_out[-1] = out_c

    return down_in, down_out, up_in, up_out

In [61]:
get_channels(in_c=6, first_c=64, max_c=512)

([6, 64, 128, 256, 512, 512, 512],
 [64, 128, 256, 512, 512, 512, 512],
 [512, 1024, 1024, 1024, 512, 256, 128],
 [512, 512, 512, 256, 128, 3])

In [None]:
# NOTE(review): `dw_1` is not defined in any visible cell, so this raised NameError
# on a fresh kernel. It is assumed to be the depthwise-separable counterpart of
# `conv_3x3` below (128 -> 256, 3x3, no padding); confirm.
dw_1 = dw_conv(128, 256, kernel_size=3, stride=1, padding=0)
# Trainable parameter count.
sum(p.numel() for p in dw_1.parameters() if p.requires_grad)

In [None]:
conv_3x3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3)  # stride 1, no padding

In [None]:
# Trainable parameter count of the plain 3x3 conv (presumably for comparison with dw_1).
sum(p.numel() for p in conv_3x3.parameters() if p.requires_grad)

In [106]:
def get_channels(in_channels=6, out_channels=3, init_channels=64, max_channels=512, num_conv=6):
    """Per-block channel counts for a ``num_conv``-deep U-Net (standalone copy of HideNet.get_channels).

    Encoder block ``i`` emits ``min(init_channels * 2**i, max_channels)`` and
    consumes the previous block's output (``in_channels`` for the first).
    Decoder block ``k`` consumes twice the ``k``-th deepest encoder width and
    emits the next-shallower one; the final decoder block emits ``out_channels``.

    Returns:
        ``(encoder_in, encoder_out, decoder_in, decoder_out)``; all four lists
        are empty when ``num_conv <= 0``.
    """
    widths = [min(init_channels * 2 ** depth, max_channels) for depth in range(num_conv)]
    inputs = ([in_channels] + widths)[:-1]
    deepest_first = widths[::-1]
    concat_inputs = [width * 2 for width in deepest_first]
    # Prepend-then-drop keeps the result empty when num_conv <= 0.
    outputs = (deepest_first + [out_channels])[1:]
    return inputs, widths, concat_inputs, outputs

In [107]:
ein, eou, din, dou = get_channels(in_channels=6, out_channels=3, init_channels=64, max_channels=512, num_conv=6)

In [108]:
# Layout: encoder | bottleneck | decoder. The bottleneck maps the last encoder
# width to itself, so its width is eou[-1] rather than a hard-coded 512.
print(ein, eou[-1], din)
print(eou, eou[-1], dou)

[6, 64, 128, 256, 512, 512] 512 [1024, 1024, 1024, 512, 256, 128]
[64, 128, 256, 512, 512, 512] 512 [512, 512, 256, 128, 64, 3]
