>### SegNet

In [2]:
import os
import numpy as np
import pandas as pd
import pickle
import glob

# Plotting
import imageio
import PIL
from PIL import ImageDraw, Image
import cv2
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
# Use a Unicode-wide font for sans-serif text, presumably so non-ASCII
# (e.g. Chinese) titles/labels render instead of showing missing-glyph boxes.
plt.rcParams['font.sans-serif']=['Arial Unicode MS'] 
# Draw minus signs with the ASCII hyphen rather than U+2212, which the
# substituted font may not provide.
plt.rcParams['axes.unicode_minus']=False 

# Framework
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds
import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import transforms
import torch.nn.functional as F

#### Model

In [119]:
class Conv(nn.Module):
    """Basic SegNet unit: Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_chans: number of channels of the incoming feature map.
        filters: number of channels produced by the convolution.
        kernel_size: size of the square convolution kernel.
        padding: zero padding on each side; with the defaults (3, 1) and the
            default stride of 1, the spatial size is preserved.
    """

    def __init__(self, in_chans, filters, kernel_size=3, padding=1):
        super().__init__()
        # These attribute names become state_dict keys (conv.*, bn.*), so they
        # must stay stable for checkpoint compatibility.
        self.conv = nn.Conv2d(in_chans, filters, kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(filters)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs):
        """Return relu(bn(conv(inputs)))."""
        return self.relu(self.bn(self.conv(inputs)))

class ConvBlock(nn.Module):
    """SegNet encoder stage: ``n_convs`` Conv units followed by max pooling.

    Args:
        n_convs: number of Conv (Conv2d -> BatchNorm2d -> ReLU) units.
        in_chans: channels of the stage's input feature map.
        filters: channels produced by every Conv unit in the stage.
        pool_size: max-pooling window size.
        pool_stride: max-pooling stride.

    ``forward`` returns ``(pooled, indices)``. The pooling layer is built with
    ``return_indices=True`` so a decoder can unpool with the same argmax
    positions.
    """

    def __init__(self, n_convs, in_chans, filters, pool_size=2, pool_stride=2):
        super(ConvBlock, self).__init__()
        block = []
        for _ in range(n_convs):
            block.append(Conv(in_chans, filters))
            # Every conv after the first consumes the previous conv's output,
            # so its input width must become `filters`. (The original
            # assigned this to an unused local, leaving each conv expecting
            # `in_chans` and breaking any stage with in_chans != filters.)
            in_chans = filters
        block.append(nn.MaxPool2d(pool_size, stride=pool_stride, return_indices=True))
        self.block = nn.Sequential(*block)

    def forward(self, inputs):
        """Run the convs and pooling; return (pooled features, pool indices)."""
        # The final MaxPool2d returns a tuple, which Sequential passes through.
        x, index = self.block(inputs)
        return x, index
    
class UpBlock(nn.Module):
    """SegNet decoder stage: max-unpool, then ``n_convs`` Conv units.

    The first ``n_convs - 1`` Conv units keep ``in_chans`` channels; the last
    one maps ``in_chans`` -> ``filters``. At least one Conv unit is always
    created.

    Args:
        n_convs: number of Conv units after unpooling.
        in_chans: channels of the incoming (pooled) feature map.
        filters: channels produced by the final Conv unit.
        pool_size: unpooling window size (should match the encoder's pool).
        pool_stride: unpooling stride (should match the encoder's pool).

    NOTE(review): MaxUnpool2d is called without ``output_size``, so it always
    produces exactly pool_stride x the input size. If the matching encoder
    input had an odd H/W, the shapes will not round-trip.
    """

    def __init__(self, n_convs, in_chans, filters, pool_size=2, pool_stride=2):
        super().__init__()
        self.unpool = nn.MaxUnpool2d(pool_size, stride=pool_stride)
        convs = [Conv(in_chans, in_chans) for _ in range(n_convs - 1)]
        convs.append(Conv(in_chans, filters))
        self.block = nn.Sequential(*convs)

    def forward(self, *inputs):
        """Unpool then convolve; called as ``up_block(x, index)``.

        Takes several positional inputs: the feature map ``x`` and the pooling
        ``index`` returned by the matching encoder stage.
        """
        x, index = inputs
        return self.block(self.unpool(x, index))

In [120]:
class SegNet(nn.Module):
    """SegNet encoder-decoder for per-pixel classification.

    Five encoder stages (64, 128, 256, 512, 512 channels) each end in a max
    pool whose argmax indices are kept. Five mirrored decoder stages unpool
    with those indices, deepest first. A 3x3 convolution maps the final 64
    channels to ``n_classes`` scores, which are normalized with a softmax over
    the channel dimension (dim=1).

    Args:
        in_chans: channels of the input image (e.g. 3 for RGB).
        n_classes: number of output classes.

    NOTE(review): the input passes through five 2x poolings, so H and W should
    be multiples of 32; otherwise unpooling (done without ``output_size``)
    cannot restore the encoder's shapes.
    NOTE(review): ``forward`` returns probabilities, not logits.
    ``nn.CrossEntropyLoss`` applies log-softmax itself. For training, pair
    these outputs with ``nn.NLLLoss`` on their log, or use the raw classifier
    output with CrossEntropyLoss.
    """

    def __init__(self, in_chans, n_classes):
        super(SegNet, self).__init__()
        # Construction order is significant: it fixes parameter initialisation
        # (RNG consumption) and state_dict ordering.
        self.block1 = ConvBlock(2, in_chans, 64)
        self.block2 = ConvBlock(2, 64, 128)
        self.block3 = ConvBlock(3, 128, 256)
        self.block4 = ConvBlock(3, 256, 512)
        self.block5 = ConvBlock(3, 512, 512)

        self.block5d = UpBlock(3, 512, 512)
        self.block4d = UpBlock(3, 512, 256)
        self.block3d = UpBlock(3, 256, 128)
        self.block2d = UpBlock(2, 128, 64)
        self.block1d = UpBlock(1, 64, 64)
        self.classifier = nn.Conv2d(64, n_classes, kernel_size=3, padding=1)

    def forward(self, x):
        """Map an (N, in_chans, H, W) batch to (N, n_classes, H, W) probabilities."""
        encoders = (self.block1, self.block2, self.block3, self.block4, self.block5)
        decoders = (self.block5d, self.block4d, self.block3d, self.block2d, self.block1d)

        # Encoder path: keep every stage's pooling indices for its decoder.
        pool_indices = []
        for encode in encoders:
            x, idx = encode(x)
            pool_indices.append(idx)

        # Decoder path: block5d pairs with block5's indices, ..., block1d with block1's.
        for decode, idx in zip(decoders, reversed(pool_indices)):
            x = decode(x, idx)

        return F.softmax(self.classifier(x), dim=1)

In [121]:
model = SegNet(3, 10)
# torchsummary's `summary` defaults to device="cuda". When a GPU is available
# it builds its dummy inputs on the GPU while this freshly constructed model
# is still on the CPU, which fails with a device/type mismatch. Pin the device
# to where the model actually lives.
summary(model, (3, 224, 224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
              Conv-4         [-1, 64, 224, 224]               0
            Conv2d-5         [-1, 64, 224, 224]          36,928
       BatchNorm2d-6         [-1, 64, 224, 224]             128
              ReLU-7         [-1, 64, 224, 224]               0
              Conv-8         [-1, 64, 224, 224]               0
         MaxPool2d-9  [[-1, 64, 112, 112], [-1, 64, 112, 112]]               0
        ConvBlock-10  [[-1, 64, 112, 112], [-1, 64, 112, 112]]               0
           Conv2d-11        [-1, 128, 112, 112]          73,856
      BatchNorm2d-12        [-1, 128, 112, 112]             256
             ReLU-13        [-1, 128, 112, 112]               0
         