# Padding Analysis of Downsampling & Upsampling

In [None]:
"""
Imports so notebook looks nice
"""

from einops import rearrange
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from inspect import isfunction
from functools import partial

Downsampling uses `CylindricalConv`, while upsampling uses `CylindricalConvTrans`. Our tensor is formatted as `(batch_size, channels, z_bin, phi_bin, r_bin)`, and both layers rely on the `torch.nn` convolution modules to actually change the tensor size. Before the convolution, however, artificial circular padding is added manually in the phi_bin dimension to mimic a cylindrical shape. The "same"-padding equation given in the code comments of each layer doesn't seem to be correct. The goal of this analysis is to see whether we can modify this artificial padding so that the code doesn't error when we stride (compress) by a factor other than 2, since the artificial padding is currently hard coded for a stride of 2.
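As a minimal illustration of what "circular" padding means for the phi_bin dimension, the sketch below wraps elements of a plain Python list around its ends. The helper name `circular_pad` is hypothetical and only mirrors the wrap-around idea; it is not the code used by the layers.

```python
def circular_pad(values, pad):
    """Wrap `pad` elements from each end of the list around to the other end."""
    if pad == 0:
        return list(values)
    if pad > len(values):
        raise ValueError("pad must not exceed the sequence length")
    # Left side receives the last `pad` elements, right side the first `pad`.
    return list(values[-pad:]) + list(values) + list(values[:pad])


phi_bins = [0, 1, 2, 3, 4]
padded = circular_pad(phi_bins, 1)
print(padded)  # [4, 0, 1, 2, 3, 4, 0]
```

With a padding of 1 on each side, a phi_bin dimension of size 10 grows to 12 before the convolution sees it.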

# Downsampling & CylindricalConv

In [None]:
"""
Downsampling code snippet:

compress_Z = True usually, compress = 2 by default
"""
# Z_stride = compress if compress_Z else 1
# if cylindrical:
#     return CylindricalConv(
#         dim,
#         dim,
#         kernel_size=(3, 4, 4),
#         stride=(Z_stride, compress, compress),
#         padding=1,
#     )

In [None]:
class CylindricalConv(nn.Module):
    """
    Cylindrical 3D Convolution layer.

    Assumes input tensor format: (batch_size, channels, z_bin, phi_bin, r_bin)
    
    All tensor size changes related to padding & stride only affect z_bin, phi_bin, r_bin.
    Padding and stride are each given per dimension (length 3), with each index corresponding
    to the respective bin above. An int is applied to all three dimensions.
    """

    def __init__(
            self, dim_in, dim_out, kernel_size=3, stride=1, groups=1, padding=0, bias=True
    ):
        super().__init__()
        # Adjust padding for circular dimension
        
        # With padding=1 this gives padding = [1, 1, 1]
        
        if isinstance(padding, int):
            padding = [padding] * 3
        else:
            padding = list(padding)

        self.kernel_size = kernel_size
        # Store the requested padding BEFORE zeroing the phi entry;
        # forward() applies padding_orig[1] manually as circular padding
        self.padding_orig = copy.copy(padding)
        padding[1] = 0  # Remove padding for phi_bin dimension
        
        # padding = [1, 0, 1] for nn.Conv3d
        
        self.conv = nn.Conv3d(
            dim_in,
            dim_out,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            padding=padding,
            bias=bias,
        )

    def forward(self, x):
        """
        Forward pass for the cylindrical convolution.

        Pads the phi_bin dimension circularly before applying convolution.
        """
        # Circular padding for the phi_bin dimension
        # To achieve 'same' use padding P = ((S-1)*W-S+F)/2, with F = filter size, S = stride, W = input size
        # Pad last dim with nothing, 2nd to last dim is circular one
        
        # padding_orig = [1, 1, 1]
        
        circ_pad = self.padding_orig[1]
        x = F.pad(x, pad=(0, 0, circ_pad, circ_pad, 0, 0), mode="circular")
        
        # x tensor size is now (batch_size, channels, z_bin, phi_bin + 2, r_bin)
        
        x = self.conv(x)
        
        # x tensor size is actually downsampled in nn.Conv3d call
        return x

[nn.Conv3d](https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html):
Shrinks each spatial dimension (z, phi, r) with dimension index $i$ from $d_{\text{i_in}}$ to $d_{\text{i_out}}$ following:
$$
d_{\text{i_out}} = \left\lfloor \frac{d_{\text{i_in}} + 2 \times \text{padding}_i - \text{dilation}_i \times (\text{kernel_size}_i -1) - 1}{\text{stride}_i} \right\rfloor + 1
$$

In this case, $\text{dilation}_i = 1$. Assuming the numerator is divisible by the stride, we can also drop the floor and simplify the equation to:
$$
d_{\text{i_out}} = \frac{d_{\text{i_in}} + 2 \times \text{padding}_i - \text{kernel_size}_i}{\text{stride}_i} + 1
$$


We know `padding = [1, 0, 1]` and `kernel_size = (3, 4, 4)`. Looking at the phi_bin dimension (index 1), $\text{padding}_i = 0$ and $\text{kernel_size}_i = 4$. For simplicity, let $\text{kernel_size}_i = k$ and $\text{stride}_i = s$. Because of the manual circular padding, the size entering nn.Conv3d is $d_{\text{i_in}} = \text{phi_bin} + \text{circ_pad} \times 2$, or more simply $d_i + p \times 2$. Ideally, the output is $d_{\text{i_out}} = d_i / s$.

Thus, our equation is now:
$$
\frac{d_i}{s} = \frac{d_i + p \times 2 - k}{s} + 1
$$

Example: initial phi_bin dimension is 10 with artificial padding 1 and stride 2:
$$
\frac{10}{2} = \frac{10 + 2 - 4}{2} + 1
$$
$$
5 = 5
$$
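The worked example can be checked numerically. The sketch below implements the documented nn.Conv3d size formula (which floors the quotient) for a single dimension in plain Python; the function name `conv_out_size` is a hypothetical helper, not part of PyTorch.

```python
def conv_out_size(d_in, kernel_size, stride, padding=0, dilation=1):
    """Output length of one dimension of a strided convolution."""
    numerator = d_in + 2 * padding - dilation * (kernel_size - 1) - 1
    return numerator // stride + 1


# phi_bin: 10 bins plus 1 circular bin per side = 12, Conv3d padding is 0
print(conv_out_size(12, kernel_size=4, stride=2, padding=0))  # 5

# z_bin and r_bin keep the regular Conv3d padding of 1
print(conv_out_size(5, kernel_size=3, stride=2, padding=1))   # 3
print(conv_out_size(30, kernel_size=4, stride=2, padding=1))  # 15
```

The z_bin and r_bin inputs (5 and 30) are taken from the logged shapes at the end of this notebook.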

To obtain a general value of $\text{circ_pad} = p$ that keeps the output size at $d_i / s$ for an arbitrary stride $s$, we multiply both sides by $s$ and solve for $p$:
$$
d_i = d_i + p \times 2 - k + s
$$
$$
p = \frac{k - s}{2}
$$

This is our ideal value for the artificial padding of the phi dimension if we were to use a stride other than 2. However, it raises issues: with a stride of 1, for instance, the padding becomes 1.5, and the `functional.pad` function only accepts integer padding. With the kernel size hard coded to 4, only even strides give an integer $p$, which restricts our options a LOT. It also raises the question of whether we should modify more than just the artificial padding if we want to support a variable stride.
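To see which strides are compatible with the hard-coded kernel size, the sketch below tabulates $p = (k - s)/2$ with exact fractions. The helper name `ideal_circ_pad_down` is hypothetical, and "usable" here only means a non-negative integer; it does not check that phi_bin itself is divisible by the stride.

```python
from fractions import Fraction


def ideal_circ_pad_down(kernel_size, stride):
    """Circular padding p = (k - s) / 2 that would give d_out = d_in / s."""
    return Fraction(kernel_size - stride, 2)


for s in (1, 2, 3, 4):
    p = ideal_circ_pad_down(4, s)
    usable = p.denominator == 1 and p >= 0
    print(f"stride={s}: p={p} usable={usable}")
# stride=1: p=3/2 usable=False
# stride=2: p=1 usable=True
# stride=3: p=1/2 usable=False
# stride=4: p=0 usable=True
```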

# Upsampling & CylindricalConvTrans

In [None]:
"""
Upsampling code snippet:

compress_Z = True usually, compress = 2 by default

For simplicity's sake, consider extra_upsample = 0 so there is no output_padding
"""
# Z_stride = compress if compress_Z else 1
# Z_kernel = 4 if extra_upsample[0] > 0 else 3
# 
# extra_upsample[0] = 0  # Ensure Z-dimension extra upsample is zero
# if cylindrical:
#     return CylindricalConvTrans(
#         dim,
#         dim,
#         kernel_size=(Z_kernel, 4, 4),
#         stride=(Z_stride, compress, compress),
#         padding=1,
#         output_padding=extra_upsample,
#     )

In [None]:
class CylindricalConvTrans(nn.Module):
    """
    Cylindrical 3D Transposed Convolution layer.

    Assumes input tensor format: (batch_size, channels, z_bin, phi_bin, r_bin)
    """

    def __init__(
            self,
            dim_in,
            dim_out,
            kernel_size=(3, 4, 4),
            stride=(1, 2, 2),
            groups=1,
            padding=1,
            output_padding=0,
    ):
        super().__init__()
        # Adjust padding for circular dimension

        # Makes padding [1, 1, 1]
        
        if not isinstance(padding, int):
            self.padding_orig = copy.copy(padding)
            padding = list(padding)
        else:
            padding = [padding] * 3
            self.padding_orig = copy.copy(padding)

        padding[1] = kernel_size[1] - 1  # Adjust padding for phi_bin dimension

        # padding = [1, 3, 1] for nn.ConvTranspose3d
        
        self.convTrans = nn.ConvTranspose3d(
            dim_in,
            dim_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )

    def forward(self, x):
        """
        Forward pass for the cylindrical transposed convolution.

        Pads the phi_bin dimension circularly before applying convolution.
        """
        # Circular padding for the phi_bin dimension
        # Out size is : O = (i-1)*S + K - 2P
        # To achieve 'same' use padding P = ((S-1)*W-S+F)/2, with F = filter size, S = stride, W = input size
        # Pad last dim with nothing, 2nd to last dim is circular one
        
        # padding_orig = [1, 1, 1]

        circ_pad = self.padding_orig[1]
        x = F.pad(x, pad=(0, 0, circ_pad, circ_pad, 0, 0), mode="circular")

        # x tensor size is now (batch_size, channels, z_bin, phi_bin + 2, r_bin)

        x = self.convTrans(x)

        # x tensor size is actually upsampled in nn.ConvTranspose3d call
        return x

[nn.ConvTranspose3d](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html):
Grows each spatial dimension (z, phi, r) with dimension index $i$ from $d_{\text{i_in}}$ to $d_{\text{i_out}}$ following:
$$
d_{\text{i_out}} = (d_{\text{i_in}} - 1) \times \text{stride}_i - 2 \times \text{padding}_i + \text{dilation}_i \times (\text{kernel_size}_i - 1) + \text{output_padding}_i + 1
$$

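Note that the equations for nn.ConvTranspose3d and nn.Conv3d describe the same relation (for an exact division by the stride) with $d_{\text{i_in}}$ and $d_{\text{i_out}}$ swapped, except that $\text{output_padding}_i$ is added at the end. Starting from the nn.Conv3d equation with the two sizes swapped and solving for $d_{\text{i_out}}$: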

$$
d_{\text{i_in}} = \frac{d_{\text{i_out}} + 2 \times \text{padding}_i - \text{dilation}_i \times (\text{kernel_size}_i -1) - 1}{\text{stride}_i} + 1
$$
$$
(d_{\text{i_in}} - 1) \times \text{stride}_i = d_{\text{i_out}} + 2 \times \text{padding}_i - \text{dilation}_i \times (\text{kernel_size}_i -1) - 1
$$
$$
d_{\text{i_out}} = (d_{\text{i_in}} - 1) \times \text{stride}_i - 2 \times \text{padding}_i + \text{dilation}_i \times (\text{kernel_size}_i -1) + (\text{output_padding}_i) + 1
$$

Again, $\text{dilation}_i = 1$. Let $\text{kernel_size}_i = k$ and $\text{stride}_i = s$, and, as mentioned above, $\text{output_padding}_i = 0$, so we can simplify the equation to:
$$
d_{\text{i_out}} = (d_{\text{i_in}} - 1) \times s - 2 \times \text{padding}_i + k
$$

Example: our compressed phi_bin dimension is 5, expected 10 after upsample with stride 2. We know $\text{padding}_i = 3$ and $k = 4$. Our $d_\text{i_in} = 7$ due to artificial padding of 1 on each side (5 + 1 + 1):
$$
10 = (7 - 1) \times 2 - 2 \times 3 + 4
$$
$$
10 = 10
$$
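The same check can be done for the transposed convolution. This sketch implements the documented nn.ConvTranspose3d size formula for one dimension; `conv_transpose_out_size` is a hypothetical helper name, and the z_bin and r_bin inputs (3 and 15) are illustrative values assuming the upsample receives the downsampled shape from the log at the end of this notebook.

```python
def conv_transpose_out_size(d_in, kernel_size, stride, padding=0,
                            output_padding=0, dilation=1):
    """Output length of one dimension of a transposed convolution."""
    return ((d_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# phi_bin: 5 bins plus 1 circular bin per side = 7, padding = kernel_size - 1 = 3
print(conv_transpose_out_size(7, kernel_size=4, stride=2, padding=3))   # 10

# r_bin and z_bin (Z_kernel = 3) use the regular padding of 1
print(conv_transpose_out_size(15, kernel_size=4, stride=2, padding=1))  # 30
print(conv_transpose_out_size(3, kernel_size=3, stride=2, padding=1))   # 5
```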

Again, $d_{\text{i_in}} = d_i + p \times 2$ and ideally $d_{\text{i_out}} = d_i \times s$, so we solve for the artificial padding $p$:
$$
d_i \times s = (d_i + p \times 2 - 1) \times s - 2 \times \text{padding}_i + k
$$
$$
d_i \times s = d_i \times s + (p \times 2 - 1) \times s - 2 \times \text{padding}_i + k
$$
$$
2 \times \text{padding}_i - k = (p \times 2 - 1) \times s
$$
$$
\frac{2 \times \text{padding}_i - k}{s} = p \times 2 - 1
$$
$$
p = \frac{\frac{2 \times \text{padding}_i - k}{s} + 1}{2}
$$

We run into the same issue here as in the downsampling: $p$ must be an integer. Additionally, we need to ensure that the upsampled dimension matches the original input dimension across multiple operations. Since $\text{padding}_i$ is hard coded to $k - 1 = 3$ for the phi_bin dimension, this again raises the question of whether we should modify additional parameters beyond the artificial padding. A potential solution is to set $\text{padding}_i = 0$ to mimic the downsample nn.Conv3d call; however, we would still need to address the integer $p$ issue.
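A rough way to explore this is to evaluate the expression for $p$ with exact fractions, once with the current $\text{padding}_i = k - 1$ and once with the proposed $\text{padding}_i = 0$. The helper name `ideal_circ_pad_up` is hypothetical, and this is only arithmetic on the formula above, not a test of the actual layers.

```python
from fractions import Fraction


def ideal_circ_pad_up(kernel_size, stride, conv_padding):
    """p = ((2 * padding_i - k) / s + 1) / 2 from the derivation above."""
    return (Fraction(2 * conv_padding - kernel_size, stride) + 1) / 2


k = 4
for s in (1, 2, 3, 4):
    current = ideal_circ_pad_up(k, s, conv_padding=k - 1)
    zero = ideal_circ_pad_up(k, s, conv_padding=0)
    print(f"stride={s}: padding_i=3 -> p={current}, padding_i=0 -> p={zero}")
# stride=1: padding_i=3 -> p=3/2, padding_i=0 -> p=-3/2
# stride=2: padding_i=3 -> p=1, padding_i=0 -> p=-1/2
# stride=3: padding_i=3 -> p=5/6, padding_i=0 -> p=-1/6
# stride=4: padding_i=3 -> p=3/4, padding_i=0 -> p=0
```

Under these assumptions, $\text{padding}_i = 3$ only yields an integer $p$ for a stride of 2, while $\text{padding}_i = 0$ yields a non-negative integer only for a stride of 4, so changing $\text{padding}_i$ alone does not appear to remove the integer constraint.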

Look to Emily for using both ConvTranspose3d and Conv3d to achieve this?

Compress factor of 2, layer sizes [16, 16, 16] (2 downsamples)  
dim / dim_out refers to layer dimensions  
1 downsample = ResnetBlock1 > ResnetBlock2 > Attention > Downsample (nn.Identity if last)  
1 ResnetBlock = block1 > block2 > res_conv (CylindricalConv(dim, dim_out, kernel_size=1) only if dim changes, otherwise nn.Identity())  
1 Attention = Residual(PreNorm(dim_out, LinearAttention(dim_out, cylindrical=cylindrical)))  
1 Linear Attention = to_qkv (CylindricalConv(dim, hidden_dim*3, kernel_size=1, bias=False)) > to_out (Sequential(CylindricalConv(hidden_dim, dim, kernel_size=1), GroupNorm))  
1 block = proj (CylindricalConv(dim, dim_out, kernel_size=3, padding=1)) > norm (GroupNorm) > activation (SiLU)  

In this case:  
block 1 > block 2 > nn.Identity  
```
Initial Conv: torch.Size([128, 16, 5, 10, 30])

DOWNSAMPLE RESNET BLOCK START

ResnetBlock1
    Block 1
    CylindricalConv forward pad: torch.Size([128, 16, 5, 12, 30])
    CylindricalConv internal conv: torch.Size([128, 16, 5, 10, 30])

    Block 2
    CylindricalConv forward pad: torch.Size([128, 16, 5, 12, 30])
    CylindricalConv internal conv: torch.Size([128, 16, 5, 10, 30])

    Identity
    torch.Size([128, 16, 5, 10, 30])

ResnetBlock2
    Block 1
    CylindricalConv forward pad: torch.Size([128, 16, 5, 12, 30])
    CylindricalConv internal conv: torch.Size([128, 16, 5, 10, 30])

    Block 2
    CylindricalConv forward pad: torch.Size([128, 16, 5, 12, 30])
    CylindricalConv internal conv: torch.Size([128, 16, 5, 10, 30])

    Identity
    torch.Size([128, 16, 5, 10, 30])

Linear Attention
    to_qkv    
    CylindricalConv forward pad: torch.Size([128, 16, 5, 10, 30])
    CylindricalConv internal conv: torch.Size([128, 96, 5, 10, 30])

    to_out
    CylindricalConv forward pad: torch.Size([128, 32, 5, 10, 30])
    CylindricalConv internal conv: torch.Size([128, 16, 5, 10, 30])

Downsample
CylindricalConv forward pad: torch.Size([128, 16, 5, 12, 30])
CylindricalConv internal conv: torch.Size([128, 16, 3, 5, 15])

Result Downsample 1: torch.Size([128, 16, 3, 5, 15])
```
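The logged shapes can be reproduced from the size formula alone. The sketch below models the `(z_bin, phi_bin, r_bin)` sizes of one `CylindricalConv` call (manual circular padding in phi, then a convolution with phi padding 0) in plain Python. Both helper names are hypothetical, and the model assumes an int or length-3 `padding`, as in the class above.

```python
def conv_out_size(d_in, kernel_size, stride, padding=0, dilation=1):
    """Output length of one dimension of a strided convolution."""
    return (d_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def as_triple(value):
    """Expand an int to a 3-tuple; pass sequences through as tuples."""
    return (value,) * 3 if isinstance(value, int) else tuple(value)


def cylindrical_conv_shape(shape, kernel_size, stride=1, padding=0):
    """Return (shape after circular pad, shape after conv) for (z, phi, r)."""
    k, s, p = as_triple(kernel_size), as_triple(stride), as_triple(padding)
    z, phi, r = shape
    padded = (z, phi + 2 * p[1], r)   # manual circular padding in phi
    conv_pad = (p[0], 0, p[2])        # the convolution itself gets no phi padding
    out = tuple(
        conv_out_size(d, k[i], s[i], conv_pad[i]) for i, d in enumerate(padded)
    )
    return padded, out


# Block conv: kernel_size=3, padding=1
print(cylindrical_conv_shape((5, 10, 30), 3, 1, 1))
# ((5, 12, 30), (5, 10, 30))

# Downsample: kernel_size=(3, 4, 4), stride=2 in every dimension, padding=1
print(cylindrical_conv_shape((5, 10, 30), (3, 4, 4), (2, 2, 2), 1))
# ((5, 12, 30), (3, 5, 15))
```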