In [1]:
import torch

# 假设 x 是一个输入张量
x = torch.randn(1, 3, 16, 16)  # 3D 输入数据，形状为 mini-batch x channels x height x width

# 使用双线性插值法将 x 上采样到指定的尺寸
upsampled_x = torch.nn.functional.interpolate(x, size=(32, 32), mode='bilinear', align_corners=False)
upsampled_x.shape


torch.Size([1, 3, 32, 32])

In [3]:
from torch import nn


class Upsample(nn.Module):
    def __init__(self, input_feat,out_feat):
        super(Upsample, self).__init__()

        self.body = nn.Sequential(  # nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
            # dw
            nn.Conv2d(input_feat, input_feat, kernel_size=3, stride=1, padding=1, groups=input_feat, bias=False, ),
            # pw-linear
            nn.Conv2d(input_feat, out_feat * 4, 1, 1, 0, bias=False),
            # nn.BatchNorm2d(n_feat*2),
            # nn.Hardswish(),
            nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)


In [7]:
y = torch.randn(1, 16, 8, 8)


In [8]:
upsample = Upsample(16, 32)
result = upsample(y)
result.shape

torch.Size([1, 32, 16, 16])

In [9]:
class Downsample(nn.Module):
    def __init__(self, input_feat,out_feat):
        super(Downsample, self).__init__()

        self.body = nn.Sequential(  # nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
            # dw
            nn.Conv2d(input_feat, input_feat, kernel_size=3, stride=1, padding=1, groups=input_feat, bias=False, ),
            # pw-linear
            nn.Conv2d(input_feat, out_feat // 4, 1, 1, 0, bias=False),
            # nn.BatchNorm2d(n_feat // 2),
            # nn.Hardswish(),
            nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)


In [10]:
downsample = Downsample(16, 8)
result_2 = downsample(y)
result_2.shape

torch.Size([1, 8, 4, 4])

In [11]:
depths=[2, 2, 18, 2]
dpr = [
    x.item() for x in torch.linspace(0, 0.5, sum(depths))
] 
dpr

[0.0,
 0.021739130839705467,
 0.043478261679410934,
 0.06521739065647125,
 0.08695652335882187,
 0.10869565606117249,
 0.1304347813129425,
 0.15217392146587372,
 0.17391304671764374,
 0.19565217196941376,
 0.21739131212234497,
 0.239130437374115,
 0.260869562625885,
 0.28260868787765503,
 0.30434781312942505,
 0.32608693838119507,
 0.3478260636329651,
 0.3695652186870575,
 0.3913043439388275,
 0.41304346919059753,
 0.43478262424468994,
 0.45652174949645996,
 0.47826087474823,
 0.5]

In [12]:
sum(depths[:0])

0

In [13]:
sum(depths[:1])

2

In [14]:
dpr[0:2]

[0.0, 0.021739130839705467]

In [15]:
class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()

        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

In [67]:
x = torch.randn(1, 512, 32, 32)
conv1 = BasicConv2d(512, 256, 1, stride=1, padding=0, dilation=1)
x = conv1(x)
x.shape

torch.Size([1, 256, 32, 32])

In [68]:
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

y = up(x)
y.shape


torch.Size([1, 256, 64, 64])

In [60]:
y

tensor([[[[ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172],
          [ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172],
          [ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172],
          ...,
          [ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172],
          [ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172],
          [ 0.0172,  0.0172,  0.0172,  ...,  0.0172,  0.0172,  0.0172]],

         [[-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098],
          [-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098],
          [-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098],
          ...,
          [-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098],
          [-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098],
          [-0.0098, -0.0098, -0.0098,  ..., -0.0098, -0.0098, -0.0098]],

         [[ 0.0138,  0.0138,  0.0138,  ...,  0.0138,  0.0138,  0.0138],
          [ 0.0138,  0.0138,  