In [None]:
from init_notebook import *

In [None]:
# Scratch inputs: batch size, a target spatial shape and a random (bs, 4) feature batch.
bs, shape = 2, (32, 24)
feature = torch.rand((bs, 4))

# Two independent 10-sample axes over [1, 2]. With indexing="xy" the result has
# shape (len(y), len(x)): grid_x varies along columns, grid_y along rows.
grid_x, grid_y = torch.meshgrid(
    *(torch.linspace(1, 2, 10) for _ in range(2)),
    indexing="xy",
)
grid_x

In [None]:
# Portable (plain-Python) form of the IPython-only `torch.meshgrid?` docstring lookup.
help(torch.meshgrid)

In [None]:
class ConvLayer(nn.Module):
    """Conv2d / ConvTranspose2d layer conditioned on a per-sample vector.

    A condition vector of length ``condition_size`` is broadcast over the
    spatial grid and concatenated to the input channels, so the wrapped conv
    receives ``in_channels + condition_size`` input channels. Optional batch
    norm and an activation are applied around the conv.

    Args:
        in_channels: channels of the input ``x`` (excluding the condition).
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size.
        condition_size: length of the condition vector (may be 0).
        stride, padding, groups, padding_mode: forwarded to the conv module.
            ``groups`` applies to the concatenated
            ``in_channels + condition_size`` channels.
        batch_norm: whether to create a ``BatchNorm2d`` layer.
        batch_norm_pos: where batch norm is applied: 0 = on the input before
            the conv (over ``in_channels``), 1 = after the conv and before the
            activation, 2 = after the activation. Any other value creates no
            batch norm.
        activation: converted to a module by ``activation_to_module`` (defined
            outside this cell); presumably returns None for "no activation" —
            a None result skips the activation.
        transposed: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            condition_size: int,
            stride: int = 1,
            padding: int = 0,
            groups: int = 1,
            batch_norm: bool = True,
            batch_norm_pos: int = 0,
            activation: Union[None, str, Callable] = "gelu",
            padding_mode: str = "zeros",
            transposed: bool = False,
    ):
        super().__init__()
        self._batch_norm_pos = batch_norm_pos
        self._condition_size = condition_size

        # Pre-conv batch norm only sees the real input channels, not the condition.
        if batch_norm and batch_norm_pos == 0:
            self.bn = nn.BatchNorm2d(in_channels)

        self.conv = (nn.ConvTranspose2d if transposed else nn.Conv2d)(
            in_channels=in_channels + self._condition_size,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode,
            groups=groups,
        )

        if batch_norm and batch_norm_pos == 1:
            self.bn = nn.BatchNorm2d(out_channels)

        self.act = activation_to_module(activation)

        if batch_norm and batch_norm_pos == 2:
            self.bn = nn.BatchNorm2d(out_channels)

    def forward(
            self,
            x: torch.Tensor,
            condition: Optional[torch.Tensor] = None,
            output_size: Union[None, Tuple[int, int]] = None,
    ) -> torch.Tensor:
        """Apply the conditioned convolution.

        Args:
            x: input of shape (B, in_channels, H, W).
            condition: optional tensor of shape (B, condition_size). When None,
                an all-zero condition is used.
            output_size: optional (H_out, W_out). The conv result is padded
                with zeros at the bottom/right (negative differences crop) to
                reach this size.

        Returns:
            Tensor of shape (B, out_channels, H', W').

        Raises:
            ValueError: if ``condition`` is not of shape (B, condition_size).
        """
        if self._batch_norm_pos == 0 and hasattr(self, "bn"):
            x = self.bn(x)

        B, _, H, W = x.shape
        if condition is None:
            # Allocate directly with x's dtype/device instead of building the
            # zeros on the CPU and copying them over on every call.
            condition_map = x.new_zeros(B, self._condition_size, H, W)
        else:
            if tuple(condition.shape) != (B, self._condition_size):
                raise ValueError(
                    f"condition must have shape ({B}, {self._condition_size}), "
                    f"got {tuple(condition.shape)}"
                )
            # (B, C) -> (B, C, H, W): broadcast view over the spatial grid.
            condition_map = condition[:, :, None, None].expand(-1, -1, H, W).to(x)

        # Concatenate along the channel dimension.
        x = torch.cat([x, condition_map], dim=1)

        x = self.conv(x)

        if output_size is not None and tuple(x.shape[-2:]) != tuple(output_size):
            x = F.pad(x, (0, output_size[-1] - x.shape[-1], 0, output_size[-2] - x.shape[-2]))

        if self._batch_norm_pos == 1 and hasattr(self, "bn"):
            x = self.bn(x)

        if self.act is not None:
            x = self.act(x)

        if self._batch_norm_pos == 2 and hasattr(self, "bn"):
            x = self.bn(x)

        return x

# Smoke test of the class defined above: 3 -> 3 channels, 3x3 kernel, 5-dim condition.
layer = ConvLayer(3, 3, 3, 5).eval()
output = layer(torch.zeros(2, 3, 10, 10), torch.randn(2, 5))
# A 3x3 kernel without padding shrinks 10x10 to 8x8.
assert tuple(output.shape) == (2, 3, 8, 8), output.shape
print(output.shape)
# to_pil_image scales floats by 255 and casts to uint8, so values outside [0, 1]
# would wrap around; the conv/activation outputs are not confined to that range.
VF.to_pil_image(make_grid(output.clamp(0, 1)))