In [3]:
import torch
from torch import nn

- 使用nn.Upsample时需要注意
- 在各插值模式（linear、bilinear、bicubic、trilinear）中，只有trilinear模式才能处理5D tensor（N, C, D, H, W）；另外nearest模式也支持5D输入
- 如果不显式指定align_corners，使用时会有如下warning：
/home/liguanlin/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py:2506: UserWarning: Default upsampling behavior when mode=trilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
- 其结果是会将width，depth和height都增加1倍，也就是变成原来的二倍

In [5]:
# Trilinear upsampling of 5D (N, C, D, H, W) volumes. scale_factor=2 doubles
# depth, height and width; align_corners is passed explicitly rather than
# relying on the default.
upsample = nn.Upsample(mode='trilinear', scale_factor=2, align_corners=True)

# First volume: D/H/W of 16/160/160 become 32/320/320.
net_input = torch.randn((8, 4, 16, 160, 160))
res = upsample(net_input)
print(res.size())  # torch.Size([8, 4, 32, 320, 320])

# Second volume (more channels, smaller grid): 8/80/80 become 16/160/160.
net_input = torch.randn((8, 128, 8, 80, 80))
res1 = upsample(net_input)
print(res1.size())  # torch.Size([8, 128, 16, 160, 160])

torch.Size([8, 4, 32, 320, 320])
torch.Size([8, 128, 16, 160, 160])


- nn.ConvTranspose3d() 模块的使用

- https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html

- Applies a 3D transposed convolution operator over an input image composed of several input planes. The transposed convolution operator multiplies each input value element-wise by a learnable kernel, and sums over the outputs from all input feature planes.

- This module can be seen as the gradient of Conv3d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation).

- Input：(N, Cin, Din, Hin, Win)
- Output: (N, Cout, Dout, Hout, Wout)

- 其中Din 和 Dout 分别表示depth的输入和输出
- Dout, Hout, Wout的计算公式是一致的，可以参考：https://zhuanlan.zhihu.com/p/343827706
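- Dout = (Din - 1)*stride + kernel_size - 2*padding（适用于dilation=1、output_padding=0的情况；完整公式为 Dout = (Din - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1）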


In [7]:
# Cubic 3x3x3 kernel with stride 2 and no padding: every spatial size n
# becomes (n - 1) * 2 + 3.
m = nn.ConvTranspose3d(in_channels=16, out_channels=33, kernel_size=3, stride=2)

# NOTE(review): `input` shadows the Python builtin; the name is kept as-is so
# that anything else in the notebook referring to it keeps working.
input = torch.randn((20, 16, 10, 50, 100))
output = m(input)
print(output.size())  # torch.Size([20, 33, 21, 101, 201])

# Per-axis (depth, height, width) kernel, stride and padding on the same input.
m_complex = nn.ConvTranspose3d(
    in_channels=16,
    out_channels=33,
    kernel_size=(3, 5, 2),
    stride=(2, 1, 1),
    padding=(0, 4, 2),
)
output = m_complex(input)
print(output.size())  # torch.Size([20, 33, 21, 46, 97])


torch.Size([20, 33, 21, 101, 201])
torch.Size([20, 33, 21, 46, 97])


- 如何使用nn.ConvTranspose3d将depth，width，height都增加一倍，即变成原来的二倍
- nn.ConvTranspose3d(16, 33, kernel_size=(2, 2, 2), stride=(2, 2, 2))：kernel_size与stride均为2且padding=0时，Dout = (Din - 1)*2 + 2 = 2*Din，Hout、Wout同理

In [8]:
# kernel_size == stride == 2 with no padding gives (n - 1) * 2 + 2 = 2 * n,
# so depth, height and width are each exactly doubled.
m_double = nn.ConvTranspose3d(
    in_channels=16,
    out_channels=33,
    kernel_size=(2, 2, 2),
    stride=(2, 2, 2),
)
# NOTE(review): `input` shadows the Python builtin; name kept unchanged.
input = torch.randn((20, 16, 10, 50, 100))

output = m_double(input)
print(output.size())  # torch.Size([20, 33, 20, 100, 200])

torch.Size([20, 33, 20, 100, 200])
