In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [74]:
def get_indices(cv, S):
    """
    build indices from a cost volume
    
    cv: (B, C, H, W)
    indices: (B, C_h, C_w, H, W)
    """
    C_h, C_w = cv.size()
    indices = torch.zeros((2*S+1, 2*S+1, 2))
    indices[..., 0] = torch.linspace(-S, S, 2*S+1)[:, None]
    indices[..., 1] = torch.linspace(-S, S, 2*S+1)
    # indices = indices.view(1, 2*S+1, 2*S+1, 2, 1, 1).expand(B, C_h, C_w, 2, H, W)
    # print(indices[..., 0, 0,0])
    # quit()


    indices_y = torch.linspace(-S, S, 2*S+1).view(2*S+1, 1).expand(2*S+1, 2*S+1)
    indices_x = torch.linspace(-S, S, 2*S+1).view(1, 2*S+1).expand(2*S+1, 2*S+1)

    return indices_y, indices_x

In [118]:
def gen_flow_soft(corr):
    C_h, C_w = corr.size()
    softmax = F.softmax(corr.view(C_h*C_w), dim = 0).view(C_h, C_w)
    print(softmax)
    indices_y, indices_x = get_indices(corr, C_h // 2)
    soft_argmax_y = (softmax * indices_y).sum()
    soft_argmax_x = (softmax * indices_x).sum()
    return torch.stack([soft_argmax_x, soft_argmax_y], dim = 0)

In [119]:
def gen_flow_hard(corr):
    max, _ = corr.max(1)
    _, flow_y = max.max(0)

    max, _ = corr.max(0)
    _, flow_x = max.max(0)
    flow_hard = torch.stack([flow_x, flow_y], dim = 0).float()# - 4
    return flow_hard

# Correlation

In [193]:
corr = torch.Tensor(
    [
        [1   ,   2,    3],
        [0,   3,       1],
        [0,   4,       0],
    ])

# 坐标

In [194]:
y, x = get_indices(corr, 1)
print(x + 1)
print()
print(y + 1)

tensor([[ 0.,  1.,  2.],
        [ 0.,  1.,  2.],
        [ 0.,  1.,  2.]])

tensor([[ 0.,  0.,  0.],
        [ 1.,  1.,  1.],
        [ 2.,  2.,  2.]])


# softmax

In [195]:
softmax = F.softmax(corr.view(9), dim = 0).view(3,3)
softmax

tensor([[ 0.0246,  0.0668,  0.1816],
        [ 0.0090,  0.1816,  0.0246],
        [ 0.0090,  0.4937,  0.0090]])

# indices * softmax

In [196]:
print((x + 1) * softmax)
print()
print((y + 1) * softmax)

tensor([[ 0.0000,  0.0668,  0.3632],
        [ 0.0000,  0.1816,  0.0492],
        [ 0.0000,  0.4937,  0.0181]])

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.0090,  0.1816,  0.0246],
        [ 0.0181,  0.9874,  0.0181]])


# sum

In [197]:
print(((x + 1) * softmax).sum())
print()
print(((y + 1) * softmax).sum())

tensor(1.1726)

tensor(1.2388)


# 直接argmax

In [198]:
gen_flow_hard(corr)

tensor([ 1.,  2.])

In [122]:
gen_flow_soft(corr) + 1

tensor([[ 0.0246,  0.0668,  0.1816],
        [ 0.0090,  0.1816,  0.0246],
        [ 0.0090,  0.4937,  0.0090]])


tensor([ 1.1726,  1.2388])

In [123]:
gen_flow_hard(corr)

tensor([ 1.,  2.])

# torch.nn.functional.softmax

In [33]:
x = torch.Tensor(np.array((-100, 0, 1, 2, 3, 5, 0, 0, 0)))
F.softmax(x, dim = 0)

tensor([ 0.0000,  0.0055,  0.0149,  0.0405,  0.1100,  0.8128,  0.0055,
         0.0055,  0.0055])

In [32]:
x = torch.Tensor(np.array((0, 0, 1, 2, 3, 5, 0, 0, 0)))
softmax = F.softmax(x, dim = 0)
indices = torch.linspace(-4, 4, 9)
print(indices)
print('')
print((softmax * indices).sum())
print(x.max(0)[1] - 4)

tensor([-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.])

tensor(0.7494)
tensor(1)


In [25]:
map = torch.arange(-4, 5).view(1, 9, 1, 1)
F.conv2d(softmax.view(1, ), map)

RuntimeError: input has less dimensions than expected

# torch.einsum

# Cost Volume实验记录

In [None]:
def v5():
    src = torch.Tensor(np.arange(8*3*19*19).reshape((8, 3, 19, 19)))
    tgt = torch.ones((8, 3, 19, 19))

In [28]:
def v4():
    src = torch.Tensor(np.arange(8*3*19*19).reshape((8, 3, 19, 19)))
    tgt = torch.ones((8, 3, 19, 19))

    S = 4
    B, C, H, W = src.size()
    output = torch.zeros_like(torch.zeros((B, (S*2+1)**2, H, W)))
    output[:,0] = (tgt*src).sum(1)

    I = 1
    for i in range(1, S + 1):
        # tgt下移i像素并补0, src与之对应的部分为i之后的像素, output的上i个像素为0
        output[:,I,i:,:] = (tgt[:,:,:-i,:] * src[:,:,i:,:]).sum(1); I += 1
        output[:,I,:-i,:] = (tgt[:,:,i:,:] * src[:,:,:-i,:]).sum(1); I += 1
        output[:,I,:,i:] = (tgt[:,:,:,:-i] * src[:,:,:,i:]).sum(1); I += 1
        output[:,I,:,:-i] = (tgt[:,:,:,i:] * src[:,:,:,:-i]).sum(1); I += 1

        for j in range(1, S + 1):
            output[:,I,i:,j:] = (tgt[:,:,:-i,:-j] * src[:,:,i:,j:]).sum(1); I += 1
            output[:,I,:-i,:-j] = (tgt[:,:,i:,j:] * src[:,:,:-i,:-j]).sum(1); I += 1
            output[:,I,i:,:-j] = (tgt[:,:,:-i,j:] * src[:,:,i:,:-j]).sum(1); I += 1
            output[:,I,:-i,j:] = (tgt[:,:,i:,:-j] * src[:,:,:-i,j:]).sum(1); I += 1

In [30]:
%time v4()

CPU times: user 29.3 ms, sys: 3.53 ms, total: 32.8 ms
Wall time: 21.8 ms


In [9]:
import torch

src = torch.Tensor(np.arange(8*3*4*4).reshape((8, 3, 1, 4, 4)))
tgt = torch.ones((8, 3, 4, 4))

torch.einsum('bcihw,bchw->bihw', (src, tgt)).size()

torch.Size([8, 1, 4, 4])

## nn.Module

In [135]:
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels = 3, out_channels = 128, kernel_size = 3, stride = 1, padding = 1, dilation = 1, groups = 1, bias = True),
            nn.LeakyReLU(inplace = True))
        self.convs = [nn.Sequential(
            nn.Conv2d(in_channels = 3, out_channels = 128, kernel_size = 3, stride = 1, padding = 1, dilation = 1, groups = 1, bias = True),
            nn.LeakyReLU(inplace = True))]

In [137]:
net = Net()
for module in net.modules():
    print(module)

Net(
  (conv1): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(0.01, inplace)
  )
)
Sequential(
  (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): LeakyReLU(0.01, inplace)
)
Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
LeakyReLU(0.01, inplace)


调用`.cuda()`的影响

In [138]:
net = Net()
net.cuda()
for module in net.modules():
    print(module)

AssertionError: Torch not compiled with CUDA enabled

调用`nn.DataParallel`的影响

In [139]:
net = Net()
net = nn.DataParallel(net)
for module in net.modules():
    print(module)

DataParallel(
  (module): Net(
    (conv1): Sequential(
      (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(0.01, inplace)
    )
  )
)
Net(
  (conv1): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(0.01, inplace)
  )
)
Sequential(
  (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): LeakyReLU(0.01, inplace)
)
Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
LeakyReLU(0.01, inplace)


In [57]:
src = Variable(torch.Tensor(np.arange(8*3*4*4).reshape((8, 3,1, 4, 4))))
tgt = Variable(torch.ones((8, 3, 81, 4, 4)))

In [128]:
src = Variable(torch.zeros((8,192,6,8)))
tgt = Variable(torch.zeros((8,192,6,8)))

In [46]:
torch.matmul(src, tgt).size()

RuntimeError: The size of tensor a (3) must match the size of tensor b (81) at non-singleton dimension 2.

In [42]:
8*81*16

10368

In [11]:
src = Variable(torch.Tensor(np.arange(8*3*384*448).reshape((8, 3, 384, 448))))
tgt = Variable(torch.ones((8, 3, 384+8, 448+8)))

In [12]:
def v1():
    output = Variable(torch.zeros((8,81,384,448)))
    H = 384; W = 448

    for i in range(4, 5):
        for j in range(4, 5):
            a = src[:,:,i,j].unsqueeze(1)
            b = tgt[:,:,i-4:i+5,j-4:j+5].contiguous().view(8, 3, -1)
            print(a.size(), b.size())
            tmp = torch.matmul(a, b).squeeze(1)
            output[:,:,i,j] = tmp


In [13]:
v1()

torch.Size([8, 1, 3]) torch.Size([8, 3, 81])


In [61]:
import time
t = time.time()
v1()
time.time() - t

13.19740104675293

In [60]:
%timeit v1()

13.1 s ± 1.05 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


# torch.matmul()

In [2]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F

In [3]:
B, C, H, W = 8, 128, 48, 48
S = 4
src = Variable(torch.ones((B, C, H, W)))
tgt = Variable(torch.ones((B, C, H, W)))
src.size(), tgt.size()

(torch.Size([8, 128, 48, 48]), torch.Size([8, 128, 48, 48]))

In [16]:
x_1 = src.transpose(1,2).transpose(2,3)
x_2 = F.pad(tgt, tuple([4 for _ in range(4)])).transpose(1,2).transpose(2,3)
mean_x_1 = torch.mean(x_1,3) 
mean_x_2 = torch.mean(x_2,3) 
sub_x_1 = x_1.sub(mean_x_1.expand_as(x_1))
sub_x_2 = x_2.sub(mean_x_2.expand_as(x_2))
st_dev_x_1 = torch.std(x_1,3) 
st_dev_x_2 = torch.std(x_2,3)

# TODO need optimize
out_vb = torch.zeros(1)
_y=0
_x=0
while _y < self.max_displacement*2+1:
    while _x < self.max_displacement*2+1:
        c_out = (torch.sum(sub_x_1*sub_x_2[:,_x:_x+x_1.size(1), _y:_y+x_1.size(2),:],3) / (st_dev_x_1*st_dev_x_2[:,_x:_x+x_1.size(1),
            _y:_y+x_1.size(2),:])).transpose(2,3).transpose(1,2)
        out_vb = torch.cat((out_vb,c_out),1) if len(out_vb.size())!=1 else c_out
        _x += self.stride_2
    _y += self.stride_2
return out_vb 


RuntimeError: The expanded size of the tensor (192) must match the existing size (224) at non-singleton dimension 3. at /pytorch/torch/lib/TH/generic/THTensor.c:309

In [21]:
def v2(src, tgt):
    if src.size(1) >= (S*2+1)**2:
        output = torch.zeros_like(src)[:,:(S*2+1)**2,:,:]
    else:
        output = F.pad(torch.zeros_like(src), (0,0,0,0,(S*2+1)**2 - src.size(1),0))
    tgt = F.pad(tgt, [S]*4)
    for i in range(S, H):
        for j in range(S, W):
            output[:,:,i,j] = torch.matmul(src[:,:,i,j].unsqueeze(1), tgt[:,:,i-S:i+S+1,j-S:j+S+1].contiguous().view(B, C, -1)).squeeze(1)

In [18]:
def v3(src, tgt):
    tgt_neigh = [tgt]
    for i in range(1, S + 1):
        map_up    = torch.zeros_like(tgt); map_up[:,:,i:,:]     = tgt[:,:,:-i,:]
        map_down  = torch.zeros_like(tgt); map_down[:,:,:-i,:]  = tgt[:,:,i:,:]
        map_left  = torch.zeros_like(tgt); map_left[:,:,:,i:]   = tgt[:,:,:,:-i]
        map_right = torch.zeros_like(tgt); map_right[:,:,:,:-i] = tgt[:,:,:,i:]
        tgt_neigh.extend([map_up, map_down, map_left, map_right])

        for j in range(1, S + 1):
            map_ul = torch.zeros_like(tgt); map_ul[:,:,i:,j:]   = tgt[:,:,:-i,:-j]
            map_ll = torch.zeros_like(tgt); map_ll[:,:,:-i,j:]  = tgt[:,:,i:,:-j]
            map_ur = torch.zeros_like(tgt); map_ur[:,:,i:,:-j]  = tgt[:,:,:-i,j:]
            map_lr = torch.zeros_like(tgt); map_lr[:,:,:-i,:-j] = tgt[:,:,i:,j:]
            tgt_neigh.extend([map_ul, map_ll, map_ur, map_lr])

    tgt_neigh = torch.stack(tgt_neigh, dim = 2)

    output = (src.unsqueeze(dim = 2) * tgt_neigh).sum(dim = 1)

In [41]:
def v4(src, tgt):
    f = lambda x: (x*src).sum(1)
    outputs = [f(tgt)]
    
    for i in range(1, S + 1):
        map_up = torch.zeros_like(tgt); map_up[:,:,i:,:] = tgt[:,:,:-i,:]
        map_down  = torch.zeros_like(tgt); map_down[:,:,:-i,:]  = tgt[:,:,i:,:]
        map_left  = torch.zeros_like(tgt); map_left[:,:,:,i:]   = tgt[:,:,:,:-i]
        map_right = torch.zeros_like(tgt); map_right[:,:,:,:-i] = tgt[:,:,:,i:]
        outputs.extend(list(map(f, [map_up, map_down, map_left, map_right])))

        for j in range(1, S + 1):
            map_ul = torch.zeros_like(tgt); map_ul[:,:,i:,j:]   = tgt[:,:,:-i,:-j]
            map_ll = torch.zeros_like(tgt); map_ll[:,:,:-i,j:]  = tgt[:,:,i:,:-j]
            map_ur = torch.zeros_like(tgt); map_ur[:,:,i:,:-j]  = tgt[:,:,:-i,j:]
            map_lr = torch.zeros_like(tgt); map_lr[:,:,:-i,:-j] = tgt[:,:,i:,j:]
            outputs.extend(list(map(f, [map_ul, map_ll, map_ur, map_lr])))

In [44]:
def v5(src, tgt):
    f = lambda x: (x*src).sum(1)
    outputs = [f(tgt)]
    
    for i in range(1, S + 1):
        map_up = F.pad(tgt[:,:,:-i,:], (0,0,0,i))
        map_down  = F.pad(tgt[:,:,i:,:], (0,0,0,i))
        map_left  = F.pad(tgt[:,:,:,:-i], (i,0))
        map_right = F.pad(tgt[:,:,:,i:], (i,0))
        outputs.extend(list(map(f, [map_up, map_down, map_left, map_right])))

        for j in range(1, S + 1):
            map_ul = F.pad(tgt[:,:,:-i,:-j], (j,0,i,0))
            map_ll = F.pad(tgt[:,:,i:,:-j], (j,0,i,0))
            map_ur = F.pad(tgt[:,:,:-i,j:], (j,0,i,0))
            map_lr = F.pad(tgt[:,:,i:,j:], (j,0,i,0))
            outputs.extend(list(map(f, [map_ul, map_ll, map_ur, map_lr])))

In [45]:
v4(src, tgt)

In [46]:
%timeit v4(src, tgt)
%timeit v5(src, tgt)

550 ms ± 59.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
490 ms ± 19.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
def v5(src, tgt):
    def pad_to_src(x):
        
    f = lambda x: (x*src).sum(1)
    outputs = [f(tgt)]
    
    for i in range(1, S + 1):
        map_up = torch.zeros_like(tgt); map_up[:,:,i:,:] = tgt[:,:,:-i,:]
        map_down  = torch.zeros_like(tgt); map_down[:,:,:-i,:]  = tgt[:,:,i:,:]
        map_left  = torch.zeros_like(tgt); map_left[:,:,:,i:]   = tgt[:,:,:,:-i]
        map_right = torch.zeros_like(tgt); map_right[:,:,:,:-i] = tgt[:,:,:,i:]
        outputs.extend(list(map(f, [map_up, map_down, map_left, map_right])))

        for j in range(1, S + 1):
            map_ul = torch.zeros_like(tgt); map_ul[:,:,i:,j:]   = tgt[:,:,:-i,:-j]
            map_ll = torch.zeros_like(tgt); map_ll[:,:,:-i,j:]  = tgt[:,:,i:,:-j]
            map_ur = torch.zeros_like(tgt); map_ur[:,:,i:,:-j]  = tgt[:,:,:-i,j:]
            map_lr = torch.zeros_like(tgt); map_lr[:,:,:-i,:-j] = tgt[:,:,i:,j:]
            outputs.extend(list(map(f, [map_ul, map_ll, map_ur, map_lr])))

In [1]:
print('Version 2')
%time v2(src, tgt)
print('Version 3 (占用显存过多, remove)')
# %time v3(src, tgt)

Version 2


NameError: name 'v2' is not defined

Version 3 (占用显存过多, remove)


In [38]:
output = torch.autograd.Variable(torch.zeros((8,81,384,448)))
H = 384; W = 448
# TODO: so slow! find the batch dot way
for i in range(H):
    for j in range(W):
        # TODO: pytorch的einsum该怎么写????
        tmp = [torch.matmul(src[:,:,i,j].unsqueeze(1), tgt[:,:,I,J].unsqueeze(2)) for I in range(i-4, i+5) for J in range(j-4, j+5)]
        tmp = torch.stack(tmp, dim = 1).squeeze()
        output[:,:,i,j] = tmp

KeyboardInterrupt: 

In [24]:
x = a[:,:,100,100].unsqueeze(1)
x.size()

torch.Size([8, 1, 3])

In [25]:
x

Variable containing:
(0 ,.,.) = 
  1.2900e+04  2.9284e+04  4.5668e+04

(1 ,.,.) = 
  6.2052e+04  7.8436e+04  9.4820e+04

(2 ,.,.) = 
  1.1120e+05  1.2759e+05  1.4397e+05

(3 ,.,.) = 
  1.6036e+05  1.7674e+05  1.9312e+05

(4 ,.,.) = 
  2.0951e+05  2.2589e+05  2.4228e+05

(5 ,.,.) = 
  2.5866e+05  2.7504e+05  2.9143e+05

(6 ,.,.) = 
  3.0781e+05  3.2420e+05  3.4058e+05

(7 ,.,.) = 
  3.5696e+05  3.7335e+05  3.8973e+05
[torch.FloatTensor of size 8x1x3]

In [26]:
# (8,1,3) * (8,3,128,128) 
torch.matmul(x, b.view(8,3,-1)).view(8,1,128,128)

Variable containing:
( 0 , 0 ,.,.) = 
  8.7852e+04  8.7852e+04  8.7852e+04  ...   8.7852e+04  8.7852e+04  8.7852e+04
  8.7852e+04  8.7852e+04  8.7852e+04  ...   8.7852e+04  8.7852e+04  8.7852e+04
  8.7852e+04  8.7852e+04  8.7852e+04  ...   8.7852e+04  8.7852e+04  8.7852e+04
                 ...                   ⋱                   ...                
  8.7852e+04  8.7852e+04  8.7852e+04  ...   8.7852e+04  8.7852e+04  8.7852e+04
  8.7852e+04  8.7852e+04  8.7852e+04  ...   8.7852e+04  8.7852e+04  8.7852e+04
  8.7852e+04  8.7852e+04  8.7852e+04  ...   8.7852e+04  8.7852e+04  8.7852e+04
      ⋮  

( 1 , 0 ,.,.) = 
  2.3531e+05  2.3531e+05  2.3531e+05  ...   2.3531e+05  2.3531e+05  2.3531e+05
  2.3531e+05  2.3531e+05  2.3531e+05  ...   2.3531e+05  2.3531e+05  2.3531e+05
  2.3531e+05  2.3531e+05  2.3531e+05  ...   2.3531e+05  2.3531e+05  2.3531e+05
                 ...                   ⋱                   ...                
  2.3531e+05  2.3531e+05  2.3531e+05  ...   2.3531e+05  2.3531e+0

In [27]:
c = Variable(torch.Tensor(np.arange(2*3).reshape((2,3))))
d = Variable(torch.ones((2, 3, 9, 9)))

In [28]:
torch.matmul(c.unsqueeze(1), d.view(2,3,-1)).view(2,1,9,9)

Variable containing:
(0 ,0 ,.,.) = 
   3   3   3   3   3   3   3   3   3
   3   3   3   3   3   3   3   3   3
   3   3   3   3   3   3   3   3   3
   3   3   3   3   3   3   3   3   3
   3   3   3   3   3   3   3   3   3
   3   3   3   3   3   3   3   3   3
   3   3   3   3   3   3   3   3   3
   3   3   3   3   3   3   3   3   3
   3   3   3   3   3   3   3   3   3

(1 ,0 ,.,.) = 
  12  12  12  12  12  12  12  12  12
  12  12  12  12  12  12  12  12  12
  12  12  12  12  12  12  12  12  12
  12  12  12  12  12  12  12  12  12
  12  12  12  12  12  12  12  12  12
  12  12  12  12  12  12  12  12  12
  12  12  12  12  12  12  12  12  12
  12  12  12  12  12  12  12  12  12
  12  12  12  12  12  12  12  12  12
[torch.FloatTensor of size 2x1x9x9]

CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py --train --num_levels 1 --weights 1 --dataset FlyingChairs --dataset_dir data/FlyingChairs/ --log_dir train_log/0421-Lv1 --batch_size 28 --lv_chs 192 --crop_shape 384 448 --resize_scale 1/16 

In [13]:
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 128, kernel_size = 3, stride = 2, padding = 1)
                

In [14]:
net = Net()
net.training

True

In [15]:
net.train()
net.training

True

In [16]:
import torch
with torch.no_grad():
    print(net.training)


True


In [17]:
optimizer = torch.optim.Adam(net.parameters(), 1e-4)

获取learning rate

In [20]:
optimizer.param_groups[0]['lr']

0.0001

grid sample

In [25]:
B, C, H, W = 1, 3, 5, 5
torchHorizontal = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, 1, H, W)
torchVertical = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, 1, H, W)
grid = torch.cat([torchHorizontal, torchVertical], 1).permute(0, 2, 3, 1)
print(grid)

tensor([[[[-1.0000, -1.0000],
          [-0.5000, -1.0000],
          [ 0.0000, -1.0000],
          [ 0.5000, -1.0000],
          [ 1.0000, -1.0000]],

         [[-1.0000, -0.5000],
          [-0.5000, -0.5000],
          [ 0.0000, -0.5000],
          [ 0.5000, -0.5000],
          [ 1.0000, -0.5000]],

         [[-1.0000,  0.0000],
          [-0.5000,  0.0000],
          [ 0.0000,  0.0000],
          [ 0.5000,  0.0000],
          [ 1.0000,  0.0000]],

         [[-1.0000,  0.5000],
          [-0.5000,  0.5000],
          [ 0.0000,  0.5000],
          [ 0.5000,  0.5000],
          [ 1.0000,  0.5000]],

         [[-1.0000,  1.0000],
          [-0.5000,  1.0000],
          [ 0.0000,  1.0000],
          [ 0.5000,  1.0000],
          [ 1.0000,  1.0000]]]])


In [31]:
import numpy as np
x = torch.Tensor(np.arange(B*C*H*W).reshape((B, C, H, W)))
(x == torch.nn.functional.grid_sample(x, grid, mode='bilinear', padding_mode='zeros')).all()

tensor(1, dtype=torch.uint8)

为optimizer设置scheduler

In [None]:
# def lr_lambda(epoch):
    #     iters = epoch * iter_per_epoch
    #     if iters < 4e+5: return 1e-4
    #     elif 4e+5 <= iters < 6e+5: return 5e-5
    #     elif 6e+5 <= iters < 8e+5: return 2e-5
    #     elif 8e+5 <= iters < 1e+6: return 1e-5
    #     else: return 5e-6
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)