In [1]:
import torch
import torch.nn as nn

# A transposed conv with stride 2 upsamples the feature map. Per the PyTorch docs:
#   H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
#         = (4 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 0 + 1 = 7
convT = nn.ConvTranspose2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1)
# Input: one 3-channel 4x4 feature map, laid out as (N, C, H, W).
# (A 3-D tensor would be read as an unbatched (C, H, W) input, so its first dim must equal in_channels.)
x = torch.randn(1, 3, 4, 4)
y = convT(x)
assert y.shape == (1, 3, 7, 7)
y.shape

torch.Size([1, 3, 7, 7])


In [2]:
# Looks up `parameters` on the class itself. This shows the unbound method inherited from
# nn.Module (see the repr below), not any weight tensors.
# NOTE(review): to inspect actual parameters, call it on an instance, e.g. `list(convT.parameters())`.
nn.ConvTranspose2d.parameters


<function torch.nn.modules.module.Module.parameters(self, recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]>

In [3]:
import torch
import torch.nn as nn

class SimpleDecoder(nn.Module):
    """Upsample a single-channel map into an image with four transposed convolutions.

    Each ``ConvTranspose2d`` uses kernel 3, stride 2, padding 1 and output_padding 1,
    so it exactly doubles the spatial size: 4 -> 8 -> 16 -> 32 -> 64 for a 4x4 input.

    Args:
        out_channels: number of channels of the produced image (default 3).

    Input:
        ``(B, H, W)``, in which case a channel dim of size 1 is inserted,
        or ``(B, 1, H, W)``.

    Returns:
        Tensor of shape ``(B, out_channels, 16 * H, 16 * W)`` with values in
        ``[-1, 1]`` (the final Tanh).
    """

    def __init__(self, out_channels=3):
        super().__init__()

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1, 64, kernel_size=3, stride=2, padding=1, output_padding=1),    # 4 -> 8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 128, kernel_size=3, stride=2, padding=1, output_padding=1),  # 8 -> 16
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # 16 -> 32
            nn.ReLU(),
            nn.ConvTranspose2d(64, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),  # 32 -> 64
            # Tanh outputs lie in [-1, 1], not [0, 1]. It was chosen because the dataset
            # is normalized and its pixel values actually include negatives.
            nn.Tanh(),
        )

    def forward(self, x):
        """Decode ``x`` of shape (B, H, W) or (B, 1, H, W) into (B, out_channels, 16H, 16W)."""
        if x.dim() == 3:
            # (B, H, W) -> (B, 1, H, W): add the single input channel expected by the first layer.
            x = x.unsqueeze(1)
        return self.decoder(x)

# Usage: decode a batch of two random 4x4 maps.
# Seed first so both the decoder's weight init and the random input are identical on every re-run.
torch.manual_seed(0)
x = torch.randn(2, 4, 4)  # batch size = 2
decoder = SimpleDecoder(out_channels=3)
out = decoder(x)

assert out.shape == (2, 3, 64, 64)
print(out.shape)  # -> torch.Size([2, 3, 64, 64])


torch.Size([2, 3, 64, 64])


In [4]:
# Display the decoder output produced in the previous cell. PyTorch's repr elides the middle of large tensors with "...".
# NOTE(review): consider showing a summary (e.g. `out.shape`, min/max) instead of the raw tensor.
out

tensor([[[[-0.1978, -0.1805, -0.1967,  ..., -0.1798, -0.1957, -0.1802],
          [-0.1609, -0.1594, -0.1478,  ..., -0.1520, -0.1524, -0.1617],
          [-0.1968, -0.1782, -0.2010,  ..., -0.1810, -0.1996, -0.1817],
          ...,
          [-0.1642, -0.1655, -0.1527,  ..., -0.1596, -0.1534, -0.1619],
          [-0.1977, -0.1794, -0.1967,  ..., -0.1830, -0.1981, -0.1845],
          [-0.1654, -0.1580, -0.1689,  ..., -0.1614, -0.1641, -0.1647]],

         [[-0.1226, -0.1039, -0.1207,  ..., -0.1035, -0.1188, -0.1042],
          [-0.1464, -0.1255, -0.1527,  ..., -0.1263, -0.1505, -0.1296],
          [-0.1221, -0.0884, -0.1330,  ..., -0.0921, -0.1263, -0.1025],
          ...,
          [-0.1491, -0.1216, -0.1583,  ..., -0.1208, -0.1534, -0.1305],
          [-0.1240, -0.0991, -0.1280,  ..., -0.0987, -0.1226, -0.1020],
          [-0.1403, -0.1077, -0.1412,  ..., -0.1087, -0.1405, -0.1297]],

         [[-0.0765, -0.0444, -0.0772,  ..., -0.0456, -0.0762, -0.0657],
          [-0.0892, -0.0490, -