# 模型定义及损失函数实现原理 

1. 双线性插值方法定义
2. FCN模型搭建
3. 损失函数原理补充
4. 损失函数计算过程

In [None]:
# encoding: utf-8
import torch 
from torchvision import models
from torch import nn
import torch.nn.functional as F
import numpy as np

In [None]:
def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Build a bilinear-interpolation weight tensor for ``nn.ConvTranspose2d``.

    Input channel ``c`` is connected only to output channel ``c`` (a diagonal
    kernel). A transposed convolution initialised with this weight therefore
    upsamples each channel independently by bilinear interpolation.

    Args:
        in_channels: size of the first weight dimension (ConvTranspose2d in_channels).
        out_channels: size of the second weight dimension.
        kernel_size: spatial size of the square kernel.

    Returns:
        A float32 tensor of shape
        ``(in_channels, out_channels, kernel_size, kernel_size)``. Only the first
        ``min(in_channels, out_channels)`` diagonal (c, c) slices are non-zero.
    """
    factor = (kernel_size + 1) // 2
    # Odd kernels are centred on a pixel; even kernels are centred between two pixels.
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    rows, cols = np.ogrid[:kernel_size, :kernel_size]
    # Separable "tent" profile: the weight falls off linearly with distance from the centre.
    bilinear_filter = (1 - np.abs(rows - center) / factor) * (1 - np.abs(cols - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float32)
    # Pairing range(in_channels) with range(out_channels) raises IndexError when
    # the two differ, so only fill the diagonal both dimensions share.
    diag = np.arange(min(in_channels, out_channels))
    weight[diag, diag, :, :] = bilinear_filter
    return torch.from_numpy(weight)

In [None]:
# VGG16 with batch normalisation from torchvision, sliced below into the FCN backbone stages.
# NOTE: pretrained=False means the weights are randomly initialised, so despite its
# name this network is NOT pretrained. Newer torchvision releases deprecate
# `pretrained=` in favour of `weights=`; confirm against the installed version.
pretrained_net = models.vgg16_bn(pretrained=False)
pretrained_net.features  # display the convolutional feature extractor

In [None]:
pretrained_net.features[0]  # first module of the VGG16-BN feature extractor

In [None]:
pretrained_net.features[0] # 64 convolution kernels, each 3x3x3 (3 input channels x 3x3 spatial)

In [None]:
pretrained_net.features[:7]  # modules 0-6 of the feature extractor, i.e. what FCN uses as stage1

In [None]:
class FCN(nn.Module):
    """Fully convolutional segmentation network on top of a VGG16-BN feature extractor.

    The backbone's ``features`` Sequential is sliced into five stages at module
    indices 7 / 14 / 24 / 34. The outputs of stages 3, 4 and 5 are fused by
    transposed-convolution upsampling plus element-wise addition. The fused map
    is then projected to class scores and upsampled 8x back to input resolution.

    Args:
        num_classes: number of output channels (one score map per class).
        features: optional backbone ``nn.Sequential`` to slice. It defaults to the
            notebook-level ``pretrained_net.features``, so ``FCN(num_classes)``
            behaves as before. NOTE(review): slicing does not copy modules, so
            every FCN built from the same backbone shares the same stage weights.
    """

    def __init__(self, num_classes, features=None):
        super().__init__()
        if features is None:
            features = pretrained_net.features

        self.stage1 = features[:7]
        self.stage2 = features[7:14]
        self.stage3 = features[14:24]
        self.stage4 = features[24:34]
        self.stage5 = features[34:]

        # 1x1 score heads. They are not used in forward(); they are kept so the
        # parameter names / state_dict layout stay compatible with saved checkpoints.
        self.scores1 = nn.Conv2d(512, num_classes, 1)
        self.scores2 = nn.Conv2d(512, num_classes, 1)
        self.scores3 = nn.Conv2d(128, num_classes, 1)

        # 1x1 convs that reduce the channel count before each fusion / projection step.
        self.conv_trans1 = nn.Conv2d(512, 256, 1)
        self.conv_trans2 = nn.Conv2d(256, num_classes, 1)

        # ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding).
        # Output size = (H - 1) * s - 2p + k; with k = 2s and p = s / 2 this is
        # exactly s * H. Each upsampler starts out as fixed bilinear interpolation.
        self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, 16, 8, 4, bias=False)
        with torch.no_grad():
            self.upsample_8x.weight.copy_(bilinear_kernel(num_classes, num_classes, 16))

        self.upsample_2x_1 = nn.ConvTranspose2d(512, 512, 4, 2, 1, bias=False)
        with torch.no_grad():
            self.upsample_2x_1.weight.copy_(bilinear_kernel(512, 512, 4))

        self.upsample_2x_2 = nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False)
        with torch.no_grad():
            self.upsample_2x_2.weight.copy_(bilinear_kernel(256, 256, 4))

    def forward(self, x):
        """Return per-pixel class scores of shape (N, num_classes, H, W) for input x.

        The skip-connection additions require stage outputs whose spatial sizes
        line up. Presumably H and W must be multiples of 32 (five 2x
        downsamplings in the backbone); TODO confirm for other backbones.
        """
        s1 = self.stage1(x)
        s2 = self.stage2(s1)
        s3 = self.stage3(s2)
        s4 = self.stage4(s3)
        s5 = self.stage5(s4)

        # 2x-upsample the deepest features and fuse them with the stage-4 features.
        add1 = self.upsample_2x_1(s5) + s4
        # Reduce to 256 channels, 2x-upsample, and fuse with the stage-3 features.
        add2 = self.upsample_2x_2(self.conv_trans1(add1)) + s3
        # Project to class scores and upsample 8x to the input resolution.
        return self.upsample_8x(self.conv_trans2(add2))


In [None]:
# Fake data for a shape check: a random label map with integer class ids in
# [0, 12) and a random input batch with the same spatial size (352 x 480).
gt = torch.from_numpy((np.random.rand(1, 352, 480) * 12).astype(np.int64))
print(gt)
x = torch.randn(1, 3, 352, 480)
print(x)

In [None]:
# Run the 12-class FCN on the fake batch and score it against the fake labels.
net = FCN(12)
y = net(x)
print(y.shape)  # one score map per class at the input resolution

# Log-probabilities over the class dimension, as NLLLoss expects.
out = y.log_softmax(dim=1)
print(out.shape)

# NLLLoss on log-softmax output is the usual cross-entropy; default reduction is the mean.
criterion = nn.NLLLoss()
print(gt.shape)
loss = criterion(out, gt)
loss

In [None]:
float(loss)  # the scalar loss as a plain Python float

**损失函数**

　　为了弄清损失是如何回传的，最好把计算细化到每个像素上。下面具体展示损失的数值是怎么一步步得来的。

**NLLLoss**

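　　这个损失函数的计算可以表达为：`loss(input, class) = -input[class]`，即在真实类别对应的输出上取一个负号。举例说明：三分类任务中，输入`input=[-1.233, 2.657, 0.534]`，真实标签类别`class=2`（类别下标从 0 开始），则`loss=-input[2]=-0.534`。注意这个例子只演示公式本身：NLLLoss 不会替你做归一化，所以这里得到了负的损失。实际应用中，NLLLoss 常用于多分类任务，但 input 在送入 NLLLoss() 之前需要先经过 log_softmax 函数，即先把 input 转换成概率分布的形式，再取对数。这样每个值都不大于 0，损失也就总是非负的（`log_softmax` + `NLLLoss` 等价于 `CrossEntropyLoss`）。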

In [None]:
# Tiny fake example (2 classes on a 2 x 3 image) so every number can be checked by hand.
gt = torch.from_numpy((np.random.rand(1, 2, 3) * 2).astype(np.int64))

x = torch.randn(1, 2, 2, 3)
out = x.log_softmax(dim=1)  # log-probabilities over the class dimension

# Labels, raw scores and log-probabilities, one per line with separators in between.
print(gt, '=' * 40, x, '-' * 40, out, sep='\n')

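　　直白地说，就是根据标签给出的每个像素的类别，到该像素对应类别的通道上取出 log_softmax 的值，取负号后计入损失。为什么可以这样做呢？因为 log_softmax 的输出是对数概率。如果某个像素被完全正确、且以百分之百的置信度分类，即 softmax 后正确类别的概率为 1、其他类别的概率为 0，那么该位置的 log_softmax 值为 log(1)=0，这个像素的损失就是 0。正确类别的概率越小，对数概率越负，损失就越大。注意：要求的是 softmax 后的概率为 1，而不是网络原始输出在该位置为 1、其他位置为 0。例如原始输出为 [1, 0] 时，log_softmax 的结果约为 [-0.31, -1.31]，并不为 0。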

In [None]:
# Example values from one run. gt/out are re-sampled randomly on every run, so
# your numbers will differ.
#
# gt = tensor([[[0, 1, 1],
#               [0, 0, 0]]])
# gt shape: (N, H, W) = (1, 2, 3) -- one class index per pixel

# out = tensor([[[[-0.2070, -1.0661, -0.6972],
#                 [-0.1605, -0.6022, -0.4681]],
#
#                [[-1.6767, -0.4221, -0.6891],
#                 [-1.9085, -0.7933, -0.9839]]]])
# out shape: (N, C, H, W) = (1, 2, 2, 3) -- log-probabilities per class

criterion = nn.NLLLoss(reduction='none') # keep one loss per pixel; the default is reduction='mean'
loss = criterion(out, gt)
loss

# loss = tensor([[[0.2070, 0.4221, 0.6891],
#                 [0.1605, 0.6022, 0.4681]]])
# loss shape: (N, H, W) = (1, 2, 3)

# Each entry is minus the log-probability of the labelled class at that pixel:
# loss[0][0][0] = 0.2070 comes from -out[0][i][0][0], i = gt[0][0][0] = 0
# loss[0][1][0] = 0.1605 comes from -out[0][i][1][0], i = gt[0][1][0] = 0
# loss[0][0][1] = 0.4221 comes from -out[0][i][0][1], i = gt[0][0][1] = 1

# With reduction='mean' the result is the average over all pixels:
# (0.2070 + 0.4221 + 0.6891 + 0.1605 + 0.6022 + 0.4681) / 6 = 0.4248

# criterion = nn.NLLLoss() # default reduction='mean'
# loss = criterion(out, gt) = 0.4248


In [None]:
# Hand check: the mean of the six per-pixel losses from the example run should
# equal what nn.NLLLoss() (reduction='mean') reports for those inputs.
(0.2070 + 0.4221 + 0.6891 + 0.1605 + 0.6022 + 0.4681) / 6

In [None]:
# Same loss reduced to a single scalar: the mean over all pixels (NLLLoss's default).
criterion = nn.NLLLoss(reduction='mean')
loss = criterion(out, gt)
loss

In [1]:
import socket
import time

MaxBytes = 1024 * 1024  # maximum number of bytes read per recv() call

# Single-client TCP echo server: accept one connection, print and echo back
# everything it sends, and stop when the client closes its side.
host = '115.156.213.191'
# host = socket.gethostname()
port = 1111

server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Allow re-binding the port while an earlier socket on it lingers in TIME_WAIT
# (a common cause of "OSError: [Errno 98] Address already in use" on re-runs).
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.settimeout(60)  # accept() gives up after 60 s without a client
client = None
try:
    # bind() is inside the try so a failure still closes the socket in `finally`.
    server.bind((host, port))   # bind the port
    server.listen(8)            # start listening
    client, addr = server.accept()   # wait for a client to connect
    print(addr, " 连接上了")
    while True:
        data = client.recv(MaxBytes)
        if not data:  # b'' means the client closed the connection
            print('数据为空，我要退出了')
            break
        localTime = time.asctime(time.localtime(time.time()))
        print(localTime, ' 接收到数据字节数:', len(data))
        # A recv() chunk can end inside a multi-byte UTF-8 character (or the peer
        # can send non-UTF-8 bytes); don't let that kill the echo loop.
        print(data.decode(errors='replace'))
        # send() may transmit only part of the buffer; sendall() keeps going until done.
        client.sendall(data)
except BaseException as e:
    # Deliberately broad: also reports KeyboardInterrupt (stopping the cell) and
    # accept() timeouts, then falls through to cleanup.
    print("出现异常：")
    print(repr(e))
finally:
    if client is not None:
        client.close()  # the accepted connection was previously never closed
    server.close()
    print("我已经退出了，后会无期")

出现异常：
KeyboardInterrupt()
我已经退出了，后会无期


In [5]:
import socket
import time

MaxBytes = 1024 * 1024  # maximum number of bytes read per recv() call

# Open a listening socket and wait (up to 60 s) for one client. On success
# `server` and `client` are left open so the next cell can read from `client`.
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Allow re-binding the port while an earlier socket on it lingers in TIME_WAIT
# (a common cause of "OSError: [Errno 98] Address already in use" on re-runs).
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.settimeout(60)
host = '115.156.213.191'
port = 1111
try:
    server.bind((host, port))        # bind the port
    server.listen(8)                 # start listening
    client, addr = server.accept()   # wait for a client to connect
except BaseException:
    # Bind failure, accept() timeout or interrupt: release the socket before
    # re-raising instead of leaking it.
    server.close()
    raise
print(addr, " 连接上了")

OSError: [Errno 98] Address already in use

In [6]:
# Read whatever the connected client sent (b'' means it already closed its side),
# then release both sockets even if recv() raises.
try:
    data = client.recv(MaxBytes)
    print(data)
finally:
    client.close()  # the accepted connection was previously never closed
    server.close()

b''


In [3]:
str(data, 'utf-8')  # decode the received bytes as UTF-8 text (strict errors)

'测试'