In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:


class double_conv2d_bn(nn.Module):
    """Two stacked (Conv2d -> BatchNorm2d -> ReLU) stages.

    Both convolutions share `kernel_size`, `strides` and `padding`; with the
    defaults (3, 1, 1) the spatial size of the input is preserved. The first
    conv maps in_channels -> out_channels, the second keeps out_channels.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=strides, padding=padding, bias=True)
        # Creation order (conv1, conv2, bn1, bn2) fixes parameter init order and state_dict keys.
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply conv -> BN -> ReLU twice."""
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(bn(conv(x)))
        return x
    
class deconv2d_bn(nn.Module):
    """Learned upsampling: ConvTranspose2d -> BatchNorm2d -> ReLU.

    With the default kernel_size=2, strides=2 the output height and width are
    exactly twice those of the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=strides, bias=True
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Upsample, normalize, then rectify."""
        upsampled = self.conv1(x)
        normalized = self.bn1(upsampled)
        return F.relu(normalized)
    
class Unet(nn.Module):
    """Four-level U-Net with skip connections and a sigmoid output.

    Encoder: four double-conv stages, each followed by 2x2 max-pooling, then a
    bottleneck double-conv. Decoder: four stages of (transposed-conv upsample ->
    concatenate the same-resolution encoder map along dim 1 -> double-conv),
    then a 3x3 conv to `out_channels` and a sigmoid.

    Args:
        in_channels: channels of the input (default 1, the original hard-coded value).
        out_channels: channels of the output map (default 1).
        base_channels: width of the first stage; each deeper level doubles it.
            The default 8 reproduces the original 8/16/32/64/128 widths.

    Input height and width must be divisible by 16 (four 2x poolings), otherwise
    the skip connections cannot be concatenated.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=8):
        super(Unet, self).__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))
        # Registration order matches the original so parameter init and state_dict keys are unchanged.
        self.layer1_conv = double_conv2d_bn(in_channels, c1)
        self.layer2_conv = double_conv2d_bn(c1, c2)
        self.layer3_conv = double_conv2d_bn(c2, c3)
        self.layer4_conv = double_conv2d_bn(c3, c4)
        self.layer5_conv = double_conv2d_bn(c4, c5)
        # Decoder inputs are [upsampled, skip] concatenations, hence twice the target width.
        self.layer6_conv = double_conv2d_bn(c5, c4)
        self.layer7_conv = double_conv2d_bn(c4, c3)
        self.layer8_conv = double_conv2d_bn(c3, c2)
        self.layer9_conv = double_conv2d_bn(c2, c1)
        self.layer10_conv = nn.Conv2d(c1, out_channels, kernel_size=3,
                                      stride=1, padding=1, bias=True)

        self.deconv1 = deconv2d_bn(c5, c4)
        self.deconv2 = deconv2d_bn(c4, c3)
        self.deconv3 = deconv2d_bn(c3, c2)
        self.deconv4 = deconv2d_bn(c2, c1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to an (N, out_channels, H, W) sigmoid map.

        Raises:
            RuntimeError: if H or W is not divisible by 16 (same exception type
                torch.cat would raise, but with an actionable message).
        """
        height, width = x.shape[-2], x.shape[-1]
        if height % 16 or width % 16:
            raise RuntimeError(
                f"Unet input height and width must be divisible by 16, got {height}x{width}"
            )

        # Encoder: keep each pre-pool feature map for the skip connections.
        conv1 = self.layer1_conv(x)
        conv2 = self.layer2_conv(F.max_pool2d(conv1, 2))
        conv3 = self.layer3_conv(F.max_pool2d(conv2, 2))
        conv4 = self.layer4_conv(F.max_pool2d(conv3, 2))
        conv5 = self.layer5_conv(F.max_pool2d(conv4, 2))

        # Decoder: upsample, concatenate the matching encoder map on dim 1, refine.
        conv6 = self.layer6_conv(torch.cat([self.deconv1(conv5), conv4], dim=1))
        conv7 = self.layer7_conv(torch.cat([self.deconv2(conv6), conv3], dim=1))
        conv8 = self.layer8_conv(torch.cat([self.deconv3(conv7), conv2], dim=1))
        conv9 = self.layer9_conv(torch.cat([self.deconv4(conv8), conv1], dim=1))

        return self.sigmoid(self.layer10_conv(conv9))

In [3]:
model = Unet()
inp = torch.rand(10, 1, 224, 224)
# Shape smoke test only: no autograd graph is needed for this forward pass.
with torch.no_grad():
    outp = model(inp)
assert outp.shape == (10, 1, 224, 224)
print(outp.shape)
# ==> torch.Size([10, 1, 224, 224])

torch.Size([10, 1, 224, 224])


# torch 常用函数

## 1.维度变换
`permute`(dims)

In [5]:
torch.manual_seed(1234)  # fixed seed so the random tensor is reproducible
x = torch.randn((8, 1024, 16, 512))
x.shape

torch.Size([8, 1024, 16, 512])

In [10]:
x.permute(0, 1, 3, 2).shape  # keep axes 0 and 1, swap the last two

torch.Size([8, 1024, 512, 16])

`transpose`与`permute`的异同

In [15]:
# Three pairwise axis swaps (transpose) compose into one reordering of all axes (permute).
chained = x.transpose(0, 3).transpose(2, 1).transpose(3, 2)
as_permute = x.permute(3, 2, 0, 1)
# Both are views of x's storage: same shape, strides and data pointer means the same tensor.
assert chained.shape == as_permute.shape
assert chained.stride() == as_permute.stride()
assert chained.data_ptr() == as_permute.data_ptr()
chained.shape

torch.Size([512, 16, 8, 1024])

`permute`函数与`contiguous`、`view`函数之关联

`contiguous`：`view`通常要求张量（tensor）在内存中是连续的（`contiguous`）。如果在`view`之前调用了`transpose`、`permute`等操作，就需要先调用`contiguous()`，它会返回一个内存连续的副本（若张量本身已连续则直接返回自身）。也可以直接使用`reshape`，它会在必要时自动复制。

In [11]:
a = torch.ones((10, 10))
a.is_contiguous()  # a freshly created tensor is laid out contiguously

True

In [12]:
a.t().is_contiguous()  # transposing only swaps strides, so the view is not contiguous

False

In [13]:
a.T.contiguous().is_contiguous()  # contiguous() returns a tensor with contiguous memory layout

True

In [14]:
import numpy as np

# Small concrete example: permute reorders axes, view reinterprets the same memory in a new shape.
np_arr = np.array([[[1, 2, 3], [4, 5, 6]]])  # not named `a`, which earlier holds a torch tensor
unpermuted = torch.tensor(np_arr)
print(unpermuted.size())              # torch.Size([1, 2, 3])

permuted = unpermuted.permute(2, 0, 1)
print(permuted.size())                # torch.Size([3, 1, 2])

view_test = unpermuted.view(1, 3, 2)
print(view_test.size())               # torch.Size([1, 3, 2])

# Same elements, different order: permute moves axes, view keeps memory order.
assert permuted.flatten().tolist() == [1, 4, 2, 5, 3, 6]
assert view_test.flatten().tolist() == [1, 2, 3, 4, 5, 6]
# permute returns a non-contiguous view; call .contiguous() before .view() on it.
assert not permuted.is_contiguous()

torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
torch.Size([1, 3, 2])


# test

In [None]:
b_from_list = torch.tensor([1.0, 2.0, 3.0, 4.0])  # from explicit values; torch.tensor is preferred over legacy torch.Tensor
b = torch.randn(1, 4)  # standard-normal random values, shape (1, 4)

In [65]:
torch.manual_seed(1234)
x, y = torch.randn(2, 4, 6), torch.randn(2, 4, 6)  # drawn in order: x first, then y
print(x, '\n', y)

tensor([[[-0.1117, -0.4966,  0.1631, -0.8817,  0.0539,  0.6684],
         [-0.0597, -0.4675, -0.2153,  0.8840, -0.7584, -0.3689],
         [-0.3424, -1.4020,  0.3206, -1.0219,  0.7988, -0.0923],
         [-0.7049, -1.6024,  0.2891,  0.4899, -0.3853, -0.7120]],

        [[-0.1706, -1.4594,  0.2207,  0.2463, -1.3248,  0.6970],
         [-0.6631,  1.2158, -1.4949,  0.8810, -1.1786, -0.9340],
         [-0.5675, -0.2772, -2.1834,  0.3668,  0.9380,  0.0078],
         [-0.3139, -1.1567,  1.8409, -1.0174,  1.2192,  0.1601]]]) 
 tensor([[[ 1.5985, -0.0469, -1.5270, -2.0143, -1.5173,  0.3877],
         [-1.1849,  0.6897,  1.3232,  1.8169,  0.6808,  0.7244],
         [ 0.0323, -1.6593, -1.8773,  0.7372,  0.9257,  0.9247],
         [ 0.1825, -0.0737,  0.3147, -1.0369,  0.2100,  0.6144]],

        [[ 0.0628, -0.3297, -1.7970,  0.8728,  0.7670, -0.1138],
         [-0.9428,  0.7540,  0.1407, -0.6937, -0.6159, -0.7295],
         [ 0.4308,  0.2862, -0.2481,  0.2040,  0.8519, -1.4102],
         [-0.1071

`*` 乘法：逐元素相乘（Hadamard 积），与 `torch.mul` 等价

In [23]:
elementwise = x * y  # element-wise (Hadamard) product
assert torch.equal(elementwise, torch.mul(x, y))  # the `*` operator and torch.mul give the same result
elementwise.shape

torch.Size([2, 4, 6])

矩阵乘法（`torch.mm`，仅适用于二维张量，不支持广播）

In [27]:
a, b = torch.ones(3, 4), torch.ones(4, 2)
torch.mm(a, b)  # (3x4) @ (4x2) -> (3x2); each entry sums four products of ones

tensor([[4., 4.],
        [4., 4.],
        [4., 4.]])

In [28]:
# Matrix multiplication is not commutative: b is 4x2 and a is 3x4, so the inner
# dimensions (2 vs 3) do not match and torch.mm raises. Catch it so Run All continues.
try:
    torch.mm(b, a)
except RuntimeError as err:
    print(f"RuntimeError: {err}")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x2 and 3x4)

`torch.matmul` 是支持广播（broadcast）的 `torch.mm`：对三维及以上的张量，它把前面的维度视为批量维度，对最后两维逐批做矩阵乘法。

In [60]:
a = torch.ones(5, 2, 4) * 2
b = torch.ones(5, 4, 2)
# The leading dim (5) is treated as a batch: five independent (2x4) @ (4x2) products.
batched = torch.matmul(a, b)
assert torch.equal(batched, torch.full((5, 2, 2), 8.0))  # each entry: 4 terms of 2 * 1
print(batched.shape)

torch.Size([5, 2, 2])


`Softmax`

In [38]:
m = nn.Softmax(dim=1)         # module form, normalizes across each row
logits = torch.randn(2, 3)    # named `logits` rather than `input` to avoid shadowing the builtin
print(logits)
output = m(logits)
print(output)                 # each row sums to 1
F.softmax(logits, dim=0)      # functional form over dim=0 instead: each column sums to 1

tensor([[ 1.4696, -1.3284,  1.9946],
        [-0.8209,  1.0061, -1.0664]])
tensor([[0.3634, 0.0221, 0.6144],
        [0.1250, 0.7772, 0.0978]])


tensor([[0.9081, 0.0883, 0.9553],
        [0.0919, 0.9117, 0.0447]])

张量拼接（`torch.cat`）：除拼接维度外，其余维度的大小必须一致

In [45]:
a = torch.ones(5, 4, 4)
b = 2 * torch.ones(5, 4, 2)
torch.cat([b, a], dim=2).shape  # only the concat dim may differ: 2 + 4 -> 6

torch.Size([5, 4, 6])

In [53]:

def index_points(points, idx):
    """Gather rows of `points` per batch element using integer indices.

    For each batch element b this selects points[b, idx[b, ...], :].

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] or [B, S, K]; integer (int64) indices into dim 1 of `points`
    Return:
        new_points: indexed points data, [B, S, C] or [B, S, K, C]
    """
    raw_size = idx.size()
    num_channels = points.size(-1)
    # Flatten all non-batch index dims so one gather along dim 1 covers both [B, S] and
    # [B, S, K]. The per-batch count is explicit (not -1) so empty index tensors still reshape.
    per_batch = torch.Size(raw_size[1:]).numel()
    flat_idx = idx.reshape(raw_size[0], per_batch)
    # gather needs the index to have as many dims as `points`: broadcast each index over channels.
    gathered = torch.gather(points, 1, flat_idx.unsqueeze(-1).expand(-1, -1, num_channels))
    # Explicit channel count (not -1) so zero-size results reshape too.
    return gathered.reshape(*raw_size, num_channels)

In [57]:
# Step-by-step shape trace of the reshaping done inside index_points.
a = torch.ones(4, 64, 6)                          # points: [B=4, N=64, C=6]
b = torch.randint(0, a.size(1), (4, 64, 16))      # integer indices into dim 1 of `a`: [B, S=64, K=16]
raw_size = b.size()
idx = b.reshape(raw_size[0], -1)                  # flatten S*K -> [4, 1024]
idx[..., None].expand(-1, -1, a.size(-1)).shape   # broadcast over channels -> [4, 1024, 6]

torch.Size([4, 1024, 6])

In [59]:
b = torch.ones((2, 4, 6))
print(b)
torch.einsum('bmf->bf', b)  # `m` is absent from the output, so it is summed over (like b.sum(dim=1))

tensor([[[1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]]])


tensor([[4., 4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4., 4.]])

In [3]:
torch.manual_seed(1234)  # same seed and shape as the x/y cell, so b reproduces x's values
b = torch.randn((2, 4, 6))
b

tensor([[[-0.1117, -0.4966,  0.1631, -0.8817,  0.0539,  0.6684],
         [-0.0597, -0.4675, -0.2153,  0.8840, -0.7584, -0.3689],
         [-0.3424, -1.4020,  0.3206, -1.0219,  0.7988, -0.0923],
         [-0.7049, -1.6024,  0.2891,  0.4899, -0.3853, -0.7120]],

        [[-0.1706, -1.4594,  0.2207,  0.2463, -1.3248,  0.6970],
         [-0.6631,  1.2158, -1.4949,  0.8810, -1.1786, -0.9340],
         [-0.5675, -0.2772, -2.1834,  0.3668,  0.9380,  0.0078],
         [-0.3139, -1.1567,  1.8409, -1.0174,  1.2192,  0.1601]]])

In [6]:
b.size()

torch.Size([2, 4, 6])

In [72]:
b.max(dim=1)  # max over dim 1, returning both the values and their argmax indices

torch.return_types.max(
values=tensor([[-0.0597, -0.4675,  0.3206,  0.8840,  0.7988,  0.6684],
        [-0.1706,  1.2158,  1.8409,  0.8810,  1.2192,  0.6970]]),
indices=tensor([[1, 1, 2, 1, 2, 0],
        [0, 1, 3, 1, 3, 0]]))

In [73]:
b.softmax(dim=-1)  # normalize along the last dim: every slice along it sums to 1

tensor([[[0.1466, 0.0998, 0.1930, 0.0679, 0.1730, 0.3198],
         [0.1582, 0.1052, 0.1354, 0.4064, 0.0787, 0.1161],
         [0.1218, 0.0422, 0.2364, 0.0617, 0.3814, 0.1564],
         [0.1022, 0.0417, 0.2762, 0.3376, 0.1407, 0.1015]],

        [[0.1435, 0.0396, 0.2122, 0.2177, 0.0453, 0.3417],
         [0.0713, 0.4667, 0.0310, 0.3339, 0.0426, 0.0544],
         [0.0880, 0.1176, 0.0175, 0.2240, 0.3965, 0.1564],
         [0.0596, 0.0256, 0.5137, 0.0295, 0.2759, 0.0957]]])