In [27]:
import torch
import torch.nn as nn
import torchvision.models as models

In [28]:
def make_decoder_layer(in_channels,out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1),
        nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1)
    )
    

class Unet_resnet(nn.Module):
    def __init__(self,resnet_type,pretrained=False):
        super(Unet_resnet,self).__init__()
        if resnet_type == 'resnet18':
            resnet = models.resnet.resnet18(pretrained)
            encoder_out_channels = [64, 128, 256, 512]  
        elif resnet_type == 'resnet34':
            resnet = models.resnet.resnet34(pretrained)
            encoder_out_channels = [64, 128, 256, 512]
        elif resnet_type == 'resnet50':
            resnet = models.resnet.resnet50(pretrained)
            encoder_out_channels = [256, 512, 1024, 2048]
        elif resnet_type == 'resnet101':
            resnet = models.resnet.resnet101(pretrained)
            encoder_out_channels = [256, 512, 1024, 2048]
        elif resnet_type == 'resnet152':
            resnet = models.resnet.resnet152(pretrained)
            encoder_out_channels = [256, 512, 1024, 2048]
        else:
            raise ValueError("unexpected resnet_type")


        self.encoder1 = nn.Sequential(
            nn.Conv2d(1,64,kernel_size=1),
            resnet.layer1
        )
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4
        self.encoder5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2,stride=2),
            nn.Conv2d(encoder_out_channels[-1],2*encoder_out_channels[-1],kernel_size=3,padding=1),
            nn.Conv2d(2*encoder_out_channels[-1],2*encoder_out_channels[-1],kernel_size=3,padding=1)
        )
        self.up4 = nn.ConvTranspose2d(2*encoder_out_channels[-1],encoder_out_channels[-1],stride=2,kernel_size=2)
        self.decoder4 = make_decoder_layer(2*encoder_out_channels[-1],encoder_out_channels[-1])
        
        self.up3 = nn.ConvTranspose2d(2*encoder_out_channels[-2],encoder_out_channels[-2],stride=2,kernel_size=2)
        self.decoder3 = make_decoder_layer(2*encoder_out_channels[-2],encoder_out_channels[-2])
        
        self.up2 = nn.ConvTranspose2d(2*encoder_out_channels[-3],encoder_out_channels[-3],stride=2,kernel_size=2)
        self.decoder2 = make_decoder_layer(2*encoder_out_channels[-3],encoder_out_channels[-3])
        
        self.up1 = nn.ConvTranspose2d(2*encoder_out_channels[-4],encoder_out_channels[-4],stride=2,kernel_size=2)
        self.decoder1 = make_decoder_layer(2*encoder_out_channels[-4],encoder_out_channels[-4])
        
        self.conv1x1 = nn.Conv2d(encoder_out_channels[-4],1,kernel_size=1)
        
        self.__init_weight()
        
        
    def forward(self,x):
        x = self.encoder1(x)
        skipconnect1 = x
        print("After encoder1, the shape is {}".format(x.shape))
        x = self.encoder2(x)
        skipconnect2 = x
        print("After encoder2, the shape is {}".format(x.shape))
        x = self.encoder3(x)
        skipconnect3 = x
        print("After encoder3, the shape is {}".format(x.shape))
        x = self.encoder4(x)
        skipconnect4 = x
        print("After encoder4, the shape is {}".format(x.shape))
        x = self.encoder5(x)
        print("After encoder5, the shape is {}".format(x.shape))
        x = self.up4(x)
        x = torch.cat((skipconnect4,x),dim=1)
        x = self.decoder4(x)

        x = self.up3(x)
        x = torch.cat((skipconnect3,x),dim=1)
        x = self.decoder3(x)

        x = self.up2(x)
        x = torch.cat((skipconnect2,x),dim=1)
        x = self.decoder2(x)

        x = self.up1(x)
        x = torch.cat((skipconnect1,x),dim=1)
        x = self.decoder1(x)

        x = self.conv1x1(x)
        return x
            
    def __init_weight(self):
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias,0)
            if isinstance(m,nn.ConvTranspose2d): # 暂时用kaiming初始化
                nn.init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias,0)

In [29]:
unet = Unet_resnet('resnet34')
#print(unet)
input = torch.randn(1,1,224,224)
output = unet(input)
print(output.shape)

After encoder1, the shape is torch.Size([1, 64, 224, 224])
After encoder2, the shape is torch.Size([1, 128, 112, 112])
After encoder3, the shape is torch.Size([1, 256, 56, 56])
After encoder4, the shape is torch.Size([1, 512, 28, 28])
After encoder5, the shape is torch.Size([1, 1024, 14, 14])
torch.Size([1, 1, 224, 224])


In [30]:
unet = Unet_resnet('resnet101')
#print(unet)
input = torch.randn(1,1,224,224)
output = unet(input)
print(output.shape)

After encoder1, the shape is torch.Size([1, 256, 224, 224])
After encoder2, the shape is torch.Size([1, 512, 112, 112])
After encoder3, the shape is torch.Size([1, 1024, 56, 56])
After encoder4, the shape is torch.Size([1, 2048, 28, 28])
After encoder5, the shape is torch.Size([1, 4096, 14, 14])
torch.Size([1, 1, 224, 224])
