In [46]:
import torch
from torchvision.models import resnet18
from collections import OrderedDict
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [47]:
class CBA(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU block.

    Args:
        inchannel: input channels of the convolution.
        outchannel: output channels of the convolution (and features of the BatchNorm).
        kernelsize: convolution kernel size.
        padding: convolution padding.
        bias: whether the convolution has a bias term; forwarded to ``nn.Conv2d``.
            A conv bias directly followed by BatchNorm is redundant (BN subtracts the
            per-channel mean), which is why callers pass ``bias=False``.
            NOTE: with ``bias=False`` the state_dict has no ``conv.bias`` entry, so
            checkpoints saved while this flag was ignored need ``strict=False`` to load.
    """

    def __init__(self, inchannel, outchannel, kernelsize, padding, bias):
        super(CBA, self).__init__()

        # `bias` is now honoured; previously the conv always got a bias regardless.
        self.conv = nn.Conv2d(inchannel, outchannel, kernelsize, padding=padding, bias=bias)
        self.bn = nn.BatchNorm2d(outchannel)
        self.act = nn.ReLU()

    def forward(self, x):
        """Apply conv, then batch norm, then ReLU."""
        return self.act(self.bn(self.conv(x)))

In [48]:
class UpModule(nn.Module):
    """Upsample-then-project block.

    Bilinear upsampling by ``scale`` changes the spatial size; a following CBA
    (Conv2d -> BatchNorm2d -> ReLU) maps ``inchannel`` to ``outchannel`` channels.
    """

    def __init__(self, inchannel, outchannel, kernelsize, padding, bias, scale):
        super().__init__()
        # Spatial resize (UpsamplingBilinear2d uses align_corners=True semantics).
        self.up = nn.UpsamplingBilinear2d(scale_factor=scale)
        # Channel projection.
        self.cba = CBA(inchannel, outchannel, kernelsize, padding, bias)

    def forward(self, x):
        """Upsample ``x`` spatially, then project its channels."""
        resized = self.up(x)
        projected = self.cba(resized)
        return projected

In [51]:
class Resnet18(nn.Module):
    """ResNet-18 backbone + additive upsampling neck + one output head per task.

    Args:
        heads: iterable of ``(name, out_channels)`` pairs; one head is built per pair.
            Heads whose name contains ``"hm"`` are treated as heatmaps: their final
            conv bias is initialised to -2.14 and their output goes through a sigmoid.
        pretrained: forwarded to ``torchvision.models.resnet18``. Defaults to True
            (the original behaviour); pass False to skip loading pretrained weights.

    ``forward(x)`` returns a dict mapping each head name to its output tensor, at the
    resolution of the backbone's ``layer1`` output.

    NOTE(review): ``self.features`` still contains the backbone's final two children
    (for torchvision's ResNet: avgpool and fc). They are unused in ``forward`` but their
    parameters remain in ``model.parameters()``.
    """

    def __init__(self, heads, pretrained=True):
        super(Resnet18, self).__init__()

        self.features = resnet18(pretrained=pretrained)
        # Backbone children except the last two, i.e. the stem plus layer1..layer4.
        # A plain OrderedDict (not nn.ModuleDict) so these layers are not registered a
        # second time; they are already registered under `self.features`.
        self.features_modules = OrderedDict(list(self.features._modules.items())[:-2])

        # Upsampling neck: each stage upsamples x2 and maps channels
        # 512 -> 256 -> 128 -> 64 so it can be summed with the matching backbone stage.
        self.up0 = UpModule(512, 256, 3, padding=1, bias=False, scale=2)
        self.up1 = UpModule(256, 128, 3, padding=1, bias=False, scale=2)
        self.up2 = UpModule(128, 64, 3, padding=1, bias=False, scale=2)

        # Materialise so a generator argument is not exhausted before forward() uses it.
        self.heads = list(heads)
        for name, out_channel in self.heads:
            # Per-head feature extractor: 1x1 conv 64 -> 128, then ReLU.
            ext_head = nn.Sequential(nn.Conv2d(64, 128, 1), nn.ReLU())
            setattr(self, f"{name}_ext", ext_head)
            # Task-specific 1x1 projection to the head's output channels.
            out_head = nn.Conv2d(128, out_channel, 1)
            setattr(self, name, out_head)

            if "hm" in name:
                # sigmoid(-2.14) ~= 0.105, so heatmap predictions start near 0.1
                # (presumably the CenterNet-style focal-loss prior init).
                out_head.bias.data.fill_(-2.14)

    def forward(self, x):
        """Run backbone, neck and heads.

        Args:
            x: input batch; assumes (N, 3, H, W). Presumably H and W must be divisible
               by 32 (the backbone's total stride) so the x2-upsampled maps match the
               skip tensors in the additions below -- TODO confirm.

        Returns:
            dict of head name -> output tensor ("hm" heads are sigmoid-activated).
        """
        # Run the backbone stage by stage, capturing each residual stage's output.
        keeps = {"layer4": None, "layer3": None, "layer2": None, "layer1": None}
        for name, layer in self.features_modules.items():
            x = layer(x)
            if name in keeps:
                keeps[name] = x

        # Top-down path: upsample x2 and add the backbone stage of matching shape.
        up0 = self.up0(keeps["layer4"])          # -> 256 channels
        up1 = self.up1(keeps["layer3"] + up0)    # -> 128 channels
        up2 = self.up2(keeps["layer2"] + up1)    # -> 64 channels

        x = F.relu(keeps["layer1"] + up2)

        output = {}
        for name, _ in self.heads:
            ext_head = getattr(self, f"{name}_ext")
            out_head = getattr(self, name)
            out = out_head(ext_head(x))
            if "hm" in name:
                out = out.sigmoid()
            output[name] = out
        return output

In [52]:
# Heads: "hm" (1 channel, sigmoid heatmap) plus two 2-channel heads "xy" and "wh"
# (presumably centre offset and box size -- confirm against the loss/decoder).
model = Resnet18(heads=[("hm", 1), ("xy", 2), ("wh", 2)])

In [53]:
# nn.Module defines __repr__, which is what print() shows; make that explicit.
print(repr(model))

Resnet18(
  (features): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_r