In [15]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import models

In [16]:
# Build a bilinear-interpolation kernel, used as the initial weight of a transposed convolution
def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Build bilinear-interpolation weights for initializing a transposed convolution.

    The result is laid out as ``(in_channels, out_channels, kernel_size, kernel_size)``.
    For every channel index ``i < min(in_channels, out_channels)`` the slice
    ``[i, i]`` holds the same 2-D bilinear filter; every other slice is zero.

    Args:
        in_channels: Size of the first dimension of the returned weight.
        out_channels: Size of the second dimension of the returned weight.
        kernel_size: Side length of the square filter.

    Returns:
        A ``torch.float32`` tensor of shape
        ``(in_channels, out_channels, kernel_size, kernel_size)``.
    """
    factor = (kernel_size + 1) // 2
    # Odd kernels peak on a single tap; even kernels peak between two taps.
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5

    rows, cols = np.ogrid[:kernel_size, :kernel_size]
    # Separable triangle profile: outer product of the row and column falloffs.
    filt = (1 - abs(rows - center) / factor) * (1 - abs(cols - center) / factor)

    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype='float32')
    # Index the channel diagonal with ONE shared index vector. Pairing
    # range(in_channels) with range(out_channels) cannot be broadcast (and raises)
    # as soon as the two channel counts differ.
    diag = np.arange(min(in_channels, out_channels))
    weight[diag, diag, :, :] = filt
    return torch.from_numpy(weight)


class fcn(nn.Module):
    """Fully convolutional segmentation head on top of a pretrained ResNet-34.

    The ResNet-34 children are split into three consecutive stages. Each stage
    output is projected to ``num_classes`` channels with a 1x1 convolution, and
    the coarser score maps are upsampled by transposed convolutions (initialized
    with bilinear kernels) and added to the finer ones before a final 8x
    upsampling.

    NOTE(review): the skip additions need matching spatial sizes; this presumably
    holds when the input height/width are multiples of 32 -- confirm for other sizes.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone_layers = list(models.resnet34(pretrained=True).children())

        # Backbone split: everything but the last four children, then the next two.
        self.stage1 = nn.Sequential(*backbone_layers[:-4])  # stage 1
        self.stage2 = backbone_layers[-4]  # stage 2
        self.stage3 = backbone_layers[-3]  # stage 3

        # Bring every stage to the same channel count (num_classes)
        self.scores1 = nn.Conv2d(512, num_classes, 1)
        self.scores2 = nn.Conv2d(256, num_classes, 1)
        self.scores3 = nn.Conv2d(128, num_classes, 1)

        # 8x upsampling
        self.upsample_8x = self._bilinear_upsampler(num_classes, 16, 8, 4)

        # 2x upsampling (note: `upsample_4x` also uses stride 2, despite its name)
        self.upsample_4x = self._bilinear_upsampler(num_classes, 4, 2, 1)
        self.upsample_2x = self._bilinear_upsampler(num_classes, 4, 2, 1)

    @staticmethod
    def _bilinear_upsampler(num_classes, kernel_size, stride, padding):
        """Create a bias-free ConvTranspose2d whose weight is a bilinear kernel."""
        layer = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size, stride, padding, bias=False
        )
        # Overwrite the default initialization with the bilinear kernel
        layer.weight.data = bilinear_kernel(num_classes, num_classes, kernel_size)
        return layer

    def forward(self, x):
        """Return the per-class score map produced from input batch ``x``."""
        feat_8 = self.stage1(x)  # 1/8
        feat_16 = self.stage2(feat_8)  # 1/16
        feat_32 = self.stage3(feat_16)  # 1/32

        # Coarsest scores, upsampled and fused with the middle stage
        coarse_up = self.upsample_2x(self.scores1(feat_32))  # 1/16
        mid_scores = self.scores2(feat_16) + coarse_up

        # Fused middle scores, upsampled and fused with the finest stage
        fine_scores = self.scores3(feat_8)
        mid_up = self.upsample_4x(mid_scores)  # 1/8
        fused = fine_scores + mid_up

        return self.upsample_8x(fused)  # 1/1

In [17]:
# Smoke test: run one random batch through the network and inspect the output shape.
x = torch.randn(1,3,64,64) # dummy image
num_classes = 21
net = fcn(num_classes)
y = net(x)
y.shape # last expression of the cell, so the notebook displays it

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


torch.Size([1, 21, 64, 64])