In [2]:
import glob
import json
import matplotlib.pyplot as plt
from multiprocessing import Manager
import numpy as np
import os
import time
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import kaolin as kal
from kaolin.models.PointNet2 import furthest_point_sampling
from kaolin.models.PointNet2 import fps_gather_by_index
from kaolin.models.PointNet2 import ball_query
from kaolin.models.PointNet2 import group_gather_by_index
from kaolin.models.PointNet2 import three_nn
from kaolin.models.PointNet2 import three_interpolate

import tensorflow as tf

In [3]:
class SceneflowDataset(Dataset):
    """Scene-flow dataset with one ``.npz`` file per sample.

    Each file provides ``points1``, ``points2``, ``color1``, ``color2``,
    ``flow`` and ``valid_mask1``. ``__getitem__`` returns
    ``(pos1, pos2, color1, color2, flow, mask1)``: the first five as float32
    tensors transposed to channels-first (shape (3, npoints) for 3-column
    inputs), ``mask1`` as a tensor of shape (npoints,). Colours are divided by
    255 and both point clouds are translated by the centroid of ``pos1``.

    Args:
        npoints: number of points returned per frame.
        root: directory holding ``TRAIN*.npz`` / ``TEST*.npz`` files.
        train: if True, read ``TRAIN*`` files and draw ``npoints`` random
            points per frame (without replacement, independently for the two
            frames; ``flow``/``mask1`` follow frame 1). Otherwise read
            ``TEST*`` files and keep the first ``npoints`` points.
        cache: optional dict-like object memoizing decoded samples by index
            (e.g. a ``multiprocessing.Manager().dict()`` shared between
            DataLoader workers); a fresh ``dict`` is used when None.
        cache_size: maximum number of samples stored in ``cache``.
    """

    def __init__(self, npoints=2048, root='data/data_processed_maxcut_35_20k_2k_8192', train=True, cache=None,
                 cache_size=30000):
        self.npoints = npoints
        self.train = train
        self.root = root
        pattern = 'TRAIN*.npz' if self.train else 'TEST*.npz'
        # glob order is filesystem-dependent; sort so the index -> file
        # mapping (and therefore the test-set order) is reproducible.
        self.datapath = sorted(glob.glob(os.path.join(self.root, pattern)))

        self.cache = {} if cache is None else cache
        self.cache_size = cache_size

        # Drop one known-bad training sample that contains NaN values.
        self.datapath = [d for d in self.datapath if 'TRAIN_C_0140_left_0006-0' not in d]

    def __getitem__(self, index):
        if index in self.cache:
            pos1, pos2, color1, color2, flow, mask1 = self.cache[index]
        else:
            fn = self.datapath[index]
            with open(fn, 'rb') as fp:
                data = np.load(fp)
                pos1 = data['points1'].astype('float32')
                pos2 = data['points2'].astype('float32')
                color1 = data['color1'].astype('float32') / 255
                color2 = data['color2'].astype('float32') / 255
                flow = data['flow'].astype('float32')
                mask1 = data['valid_mask1']

            if len(self.cache) < self.cache_size:
                self.cache[index] = (pos1, pos2, color1, color2, flow, mask1)

        if self.train:
            sample_idx1 = np.random.choice(pos1.shape[0], self.npoints, replace=False)
            sample_idx2 = np.random.choice(pos2.shape[0], self.npoints, replace=False)

            pos1 = pos1[sample_idx1, :]
            pos2 = pos2[sample_idx2, :]
            color1 = color1[sample_idx1, :]
            color2 = color2[sample_idx2, :]
            flow = flow[sample_idx1, :]
            mask1 = mask1[sample_idx1]
        else:
            # Basic slicing: these are views into the (possibly cached) arrays.
            pos1 = pos1[:self.npoints, :]
            pos2 = pos2[:self.npoints, :]
            color1 = color1[:self.npoints, :]
            color2 = color2[:self.npoints, :]
            flow = flow[:self.npoints, :]
            mask1 = mask1[:self.npoints]

        # Center both frames on frame 1's centroid. Deliberately out-of-place:
        # in test mode pos1/pos2 are views of the cached arrays, and an
        # in-place `-=` would silently rewrite the cache.
        pos1_center = np.mean(pos1, 0)
        pos1 = pos1 - pos1_center
        pos2 = pos2 - pos1_center

        pos1 = torch.from_numpy(pos1).t()
        pos2 = torch.from_numpy(pos2).t()
        color1 = torch.from_numpy(color1).t()
        color2 = torch.from_numpy(color2).t()
        flow = torch.from_numpy(flow).t()
        mask1 = torch.from_numpy(mask1)

        return pos1, pos2, color1, color2, flow, mask1

    def __len__(self):
        return len(self.datapath)
    
# Smoke test: load one training sample and show each tensor's shape and dtype.
train_set = SceneflowDataset(train=True)
sample = train_set[0]
points1, points2, color1, color2, flow, mask1 = sample
for tensor in sample:
    print(tensor.shape, tensor.dtype)


torch.Size([3, 2048]) torch.float32
torch.Size([3, 2048]) torch.float32
torch.Size([3, 2048]) torch.float32
torch.Size([3, 2048]) torch.float32
torch.Size([3, 2048]) torch.float32
torch.Size([2048]) torch.bool


In [4]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and make
    cuDNN pick deterministic, non-autotuned kernels, for reproducible runs."""
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def pdist2squared(x, y):
    """Pairwise squared Euclidean distances between two batched point sets.

    Args:
        x: tensor of shape (B, C, N).
        y: tensor of shape (B, C, M).

    Returns:
        (B, N, M) tensor whose [b, i, j] entry is ||x[b, :, i] - y[b, :, j]||^2.
        NaN entries are replaced by 0 and negative values (floating-point
        cancellation in the expansion below) are clamped to 0.
    """
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>
    x_sq = x.pow(2).sum(dim=1)[:, :, None]      # (B, N, 1)
    y_sq = y.pow(2).sum(dim=1)[:, None, :]      # (B, 1, M)
    cross = torch.bmm(x.transpose(1, 2), y)     # (B, N, M)
    dist = x_sq + y_sq - 2.0 * cross
    dist = dist.masked_fill(torch.isnan(dist), 0.0)
    return dist.clamp(min=0.0)

def parameter_count(model):
    """Return the number of trainable scalars (parameters with requires_grad) in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class ClippedStepLR(optim.lr_scheduler._LRScheduler):
    """Step-decay learning-rate schedule with a floor.

    lr = max(base_lr * gamma ** (last_epoch // step_size), min_lr) for every
    parameter group, i.e. ``StepLR`` that never drops below ``min_lr``.
    """

    def __init__(self, optimizer, step_size, min_lr, gamma=0.1, last_epoch=-1):
        # Set before the base constructor: it performs an initial step(),
        # which calls get_lr() and therefore needs these attributes.
        self.step_size = step_size
        self.min_lr = min_lr
        self.gamma = gamma
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        decay = self.gamma ** (self.last_epoch // self.step_size)
        return [max(lr0 * decay, self.min_lr) for lr0 in self.base_lrs]
    
def criterion(pred_flow, flow, mask):
    """Masked L2 scene-flow loss.

    Per point: half the squared error summed over dim 1 (the xyz channel for
    (B, 3, N) inputs), multiplied by ``mask``. The result is averaged over all
    points, so masked-out points contribute 0 but still count in the mean.
    """
    residual = pred_flow - flow
    sq_err = (residual * residual).sum(dim=1)
    return (mask * sq_err / 2.0).mean()

def error(pred, labels, mask):
    """Scene-flow metrics on masked points.

    Args:
        pred, labels: flow tensors of shape (B, 3, N).
        mask: (B, N) validity mask.

    Returns:
        (epe, acc050, acc010), each averaged over the samples that have at
        least one valid point:
        - epe: mean end-point error over valid points;
        - acc050 / acc010: fraction of valid points whose absolute error or
          error relative to the ground-truth flow length is <= 0.05 / 0.1.
    """
    pred_np = pred.permute(0, 2, 1).cpu().numpy()
    gt_np = labels.permute(0, 2, 1).cpu().numpy()
    valid = mask.cpu().numpy()

    # The 1e-20 keeps sqrt (and the relative error's division) well defined.
    err = np.sqrt(np.sum((pred_np - gt_np) ** 2, 2) + 1e-20)
    gt_len = np.sqrt(np.sum(gt_np * gt_np, 2) + 1e-20)  # (B, N)
    rel_err = err / gt_len

    counts = np.sum(valid, 1)
    has_valid = counts > 0
    denom = counts[has_valid]

    def _accuracy(thresh):
        hits = np.logical_or((err <= thresh) * valid, (rel_err <= thresh) * valid)
        return np.mean(np.sum(hits, axis=1)[has_valid] / denom)

    acc050 = _accuracy(0.05)
    acc010 = _accuracy(0.1)
    epe = np.mean(np.sum(err * valid, 1)[has_valid] / denom)
    return epe, acc050, acc010

In [5]:
class Sample(nn.Module):
    """Furthest-point sampling of ``num_points`` points per cloud (kaolin ops).

    ``forward`` takes channels-first points (B, 3, N), as produced by
    SceneflowDataset, and returns the gathered subset (presumably
    (B, 3, num_points); the layout is defined by kaolin's fps_gather_by_index).
    """

    def __init__(self, num_points):
        super(Sample, self).__init__()
        self.num_points = num_points

    def forward(self, points):
        # The sampler is fed the transposed (points-last) layout.
        points_last = points.permute(0, 2, 1).contiguous()
        idx = furthest_point_sampling(points_last, self.num_points)
        return fps_gather_by_index(points, idx)
    
class Group(nn.Module):
    """Neighbourhood grouping around query points.

    For every point of ``new_points`` (the queries), gather ``num_samples``
    points of ``points``, either by ball query with ``radius`` or, when
    ``knn`` is True, as the k nearest neighbours (``radius`` unused). Return,
    concatenated on dim 1, the neighbours' coordinates relative to their query
    point followed by their ``features``.
    """

    def __init__(self, radius, num_samples, knn=False):
        super(Group, self).__init__()
        self.radius = radius
        self.num_samples = num_samples
        self.knn = knn

    def forward(self, points, new_points, features):
        if not self.knn:
            ind = ball_query(self.radius, self.num_samples,
                             points.permute(0, 2, 1).contiguous(),
                             new_points.permute(0, 2, 1).contiguous(), False)
        else:
            # sq_dist is (B, N_points, N_queries): take the k smallest along
            # the point axis, then lay indices out as (B, N_queries, k).
            sq_dist = pdist2squared(points, new_points)
            _, nearest = sq_dist.topk(self.num_samples, dim=1, largest=False)
            ind = nearest.int().permute(0, 2, 1).contiguous()
        # Offset of every neighbour from its query point.
        rel_points = group_gather_by_index(points, ind) - new_points.unsqueeze(3)
        grouped_features = group_gather_by_index(features, ind)
        return torch.cat([rel_points, grouped_features], dim=1)

class SetConv(nn.Module):
    """Set-abstraction layer: FPS-sample ``num_points`` centroids, group
    ``num_samples`` neighbours within ``radius``, apply a shared 1x1
    Conv2d -> BatchNorm2d -> ReLU MLP and max-pool over the neighbours.

    The MLP input has ``in_channels + 3`` channels (relative xyz + features);
    ``out_channels`` lists the width of each MLP layer. ``forward`` returns
    ``(centroids, pooled_features)``.
    """

    def __init__(self, num_points, radius, num_samples, in_channels, out_channels):
        super(SetConv, self).__init__()
        self.sample = Sample(num_points)
        self.group = Group(radius, num_samples)

        widths = [in_channels + 3] + list(out_channels)
        blocks = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks.append(nn.Conv2d(c_in, c_out, 1, bias=True))
            blocks.append(nn.BatchNorm2d(c_out, eps=0.001))
            blocks.append(nn.ReLU())
        self.conv = nn.Sequential(*blocks)

    def forward(self, points, features):
        centroids = self.sample(points)
        grouped = self.group(points, centroids, features)
        pooled, _ = self.conv(grouped).max(dim=3)
        return centroids, pooled
    
class FlowEmbedding(nn.Module):
    """Flow-embedding layer of FlowNet3D.

    For every point of frame 1, its ``num_samples`` nearest neighbours in
    frame 2 are grouped as [relative xyz, frame-2 feature]. The frame-1 feature
    is appended (``2 * in_channels + 3`` channels in total), a shared
    Conv2d/BatchNorm2d/ReLU MLP is applied and the result is max-pooled over
    the neighbours.
    """

    def __init__(self, num_samples, in_channels, out_channels):
        super(FlowEmbedding, self).__init__()
        self.num_samples = num_samples
        self.group = Group(None, self.num_samples, knn=True)

        widths = [2 * in_channels + 3] + list(out_channels)
        blocks = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks.append(nn.Conv2d(c_in, c_out, 1, bias=True))
            blocks.append(nn.BatchNorm2d(c_out, eps=0.001))
            blocks.append(nn.ReLU())
        self.conv = nn.Sequential(*blocks)

    def forward(self, points1, points2, features1, features2):
        # Queries are frame-1 points, candidates are frame-2 points.
        pairs = self.group(points2, points1, features2)
        # Repeat each frame-1 feature once per neighbour.
        anchor = features1.unsqueeze(3).expand(-1, -1, -1, self.num_samples)
        embedded = self.conv(torch.cat([pairs, anchor], dim=1))
        return embedded.max(dim=3)[0]
    
class SetUpConv(nn.Module):
    """Set up-convolution: propagate ``features1`` (on ``points1``) to ``points2``.

    For each point of ``points2``, its ``num_samples`` nearest neighbours in
    ``points1`` are grouped (relative xyz + features1) and passed through
    ``conv1``, which may be empty. The result is max-pooled over the
    neighbours, concatenated with the skip features ``features2`` and passed
    through the per-point ``conv2`` MLP.
    """

    def __init__(self, num_samples, in_channels1, in_channels2, out_channels1, out_channels2):
        super(SetUpConv, self).__init__()
        self.group = Group(None, num_samples, knn=True)

        widths1 = [in_channels1 + 3] + list(out_channels1)
        stage1 = []
        for c_in, c_out in zip(widths1[:-1], widths1[1:]):
            stage1 += [nn.Conv2d(c_in, c_out, 1, bias=True), nn.BatchNorm2d(c_out, eps=0.001), nn.ReLU()]
        self.conv1 = nn.Sequential(*stage1)

        # conv2 consumes the pooled conv1 output (or, when conv1 is empty, the
        # raw in_channels1 + 3 grouped channels) plus the in_channels2 skip channels.
        widths2 = [widths1[-1] + in_channels2] + list(out_channels2)
        stage2 = []
        for c_in, c_out in zip(widths2[:-1], widths2[1:]):
            stage2 += [nn.Conv2d(c_in, c_out, 1, bias=True), nn.BatchNorm2d(c_out, eps=0.001), nn.ReLU()]
        self.conv2 = nn.Sequential(*stage2)

    def forward(self, points1, points2, features1, features2):
        grouped = self.group(points1, points2, features1)
        pooled = self.conv1(grouped).max(dim=3)[0]
        merged = torch.cat([pooled, features2], dim=1)
        return self.conv2(merged.unsqueeze(3)).squeeze(3)
    
class FeaturePropagation(nn.Module):
    """Feature propagation by 3-nearest-neighbour interpolation.

    ``features1`` (defined on ``points1``) is interpolated onto ``points2``
    from the 3 nearest ``points1`` neighbours, using weights proportional to
    1 / dist**2 (dist as returned by kaolin's three_nn). The result is
    concatenated with ``features2`` and passed through a per-point
    Conv2d/BatchNorm2d/ReLU MLP.
    """

    def __init__(self, in_channels1, in_channels2, out_channels):
        super(FeaturePropagation, self).__init__()
        widths = [in_channels1 + in_channels2] + list(out_channels)
        blocks = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks.append(nn.Conv2d(c_in, c_out, 1, bias=True))
            blocks.append(nn.BatchNorm2d(c_out, eps=0.001))
            blocks.append(nn.ReLU())
        self.conv = nn.Sequential(*blocks)

    def forward(self, points1, points2, features1, features2):
        dist, ind = three_nn(points2.permute(0, 2, 1).contiguous(), points1.permute(0, 2, 1).contiguous())
        # Floor the squared distance so coincident points do not divide by zero.
        sq_dist = (dist * dist).clamp(min=1e-10)
        inv = 1.0 / sq_dist
        weights = inv / inv.sum(dim=2, keepdim=True)
        # three_interpolate(features1, ind, weights) is deliberately not used:
        # it was found to give wrong gradients. Gather + weighted sum instead.
        interpolated = (group_gather_by_index(features1, ind) * weights.unsqueeze(1)).sum(dim=3)
        merged = torch.cat([interpolated, features2], dim=1)
        return self.conv(merged.unsqueeze(3)).squeeze(3)

class FlowNet3D(nn.Module):
    """FlowNet3D scene-flow network.

    ``forward(points1, points2, features1, features2)`` takes two point clouds
    with 3-channel per-point features (channels-first, e.g. the xyz / RGB
    tensors from SceneflowDataset). It returns a 3-channel flow prediction for
    every point of frame 1.
    """

    def __init__(self):
        super(FlowNet3D, self).__init__()

        # Set abstraction (shared between the two frames for levels 1-2).
        self.set_conv1 = SetConv(num_points=1024, radius=0.5, num_samples=16, in_channels=3, out_channels=[32, 32, 64])
        self.set_conv2 = SetConv(num_points=256, radius=1.0, num_samples=16, in_channels=64, out_channels=[64, 64, 128])
        self.flow_embedding = FlowEmbedding(num_samples=64, in_channels=128, out_channels=[128, 128, 128])
        self.set_conv3 = SetConv(num_points=64, radius=2.0, num_samples=8, in_channels=128, out_channels=[128, 128, 256])
        self.set_conv4 = SetConv(num_points=16, radius=4.0, num_samples=8, in_channels=256, out_channels=[256, 256, 512])

        # Up-sampling back to the input resolution of frame 1.
        self.set_upconv1 = SetUpConv(num_samples=8, in_channels1=512, in_channels2=256,
                                     out_channels1=[], out_channels2=[256, 256])
        self.set_upconv2 = SetUpConv(num_samples=8, in_channels1=256, in_channels2=256,
                                     out_channels1=[128, 128, 256], out_channels2=[256])
        self.set_upconv3 = SetUpConv(num_samples=8, in_channels1=256, in_channels2=64,
                                     out_channels1=[128, 128, 256], out_channels2=[256])
        self.fp = FeaturePropagation(in_channels1=256, in_channels2=3, out_channels=[256, 256])

        # Per-point head: 256 -> 128 -> 3 flow components.
        self.classifier = nn.Sequential(
            nn.Conv1d(256, 128, 1, bias=True),
            nn.BatchNorm1d(128, eps=0.001),
            nn.ReLU(),
            nn.Conv1d(128, 3, 1, bias=True)
        )

    def forward(self, points1, points2, features1, features2):
        # Frame 1 encoder.
        xyz1_l1, feat1_l1 = self.set_conv1(points1, features1)
        xyz1_l2, feat1_l2 = self.set_conv2(xyz1_l1, feat1_l1)

        # Frame 2 encoder (same weights).
        xyz2_l1, feat2_l1 = self.set_conv1(points2, features2)
        xyz2_l2, feat2_l2 = self.set_conv2(xyz2_l1, feat2_l1)

        flow_feat = self.flow_embedding(xyz1_l2, xyz2_l2, feat1_l2, feat2_l2)

        xyz1_l3, feat1_l3 = self.set_conv3(xyz1_l2, flow_feat)
        xyz1_l4, feat1_l4 = self.set_conv4(xyz1_l3, feat1_l3)

        up_l3 = self.set_upconv1(xyz1_l4, xyz1_l3, feat1_l4, feat1_l3)
        up_l2 = self.set_upconv2(xyz1_l3, xyz1_l2, up_l3, torch.cat([feat1_l2, flow_feat], dim=1))
        up_l1 = self.set_upconv3(xyz1_l2, xyz1_l1, up_l2, feat1_l1)
        up_l0 = self.fp(xyz1_l1, points1, up_l1, features1)

        return self.classifier(up_l0)


In [6]:
# Read every variable of the reference TensorFlow FlowNet3D checkpoint into
# tf_vars: {variable name: numpy array}.
# NOTE(review): '/path/to/model.ckpt' is a placeholder; point it at the real
# checkpoint prefix before running this cell.
tf_path = os.path.abspath('/path/to/model.ckpt')
init_vars = tf.train.list_variables(tf_path)
tf_vars = {var_name: tf.train.load_variable(tf_path, var_name)
           for var_name, _ in init_vars}

# Name table between this PyTorch port and the TF FlowNet3D checkpoint:
# mapping[<PyTorch state-dict key>] = <TF variable name>.
# A TF conv scope "<scope>" holds weights/biases plus a batch norm under
# "<scope>/bn/". In PyTorch every conv is a Conv -> BatchNorm -> ReLU triple
# inside an nn.Sequential, so conv i sits at index 3*i and its batch norm at 3*i + 1.
def _conv_bn_names(tf_scope, torch_seq, idx):
    """Return {PyTorch key: TF name} for one conv + batch-norm layer."""
    conv = '{}.{}'.format(torch_seq, idx)
    bn = '{}.{}'.format(torch_seq, idx + 1)
    return {
        conv + '.weight': tf_scope + '/weights',
        conv + '.bias': tf_scope + '/biases',
        bn + '.weight': tf_scope + '/bn/gamma',
        bn + '.bias': tf_scope + '/bn/beta',
        bn + '.running_mean': tf_scope + '/bn/moving_mean',
        bn + '.running_var': tf_scope + '/bn/moving_variance',
    }

# (TF scope prefix, PyTorch nn.Sequential, number of conv layers)
_CONV_STACKS = [
    ('sa1/layer1/conv', 'set_conv1.conv', 3),
    ('sa1/layer2/conv', 'set_conv2.conv', 3),
    ('flow_embedding/conv_diff_', 'flow_embedding.conv', 3),
    ('layer3/conv', 'set_conv3.conv', 3),
    ('layer4/conv', 'set_conv4.conv', 3),
    ('up_sa_layer1/post-conv', 'set_upconv1.conv2', 2),
    ('up_sa_layer2/conv', 'set_upconv2.conv1', 3),
    ('up_sa_layer2/post-conv', 'set_upconv2.conv2', 1),
    ('up_sa_layer3/conv', 'set_upconv3.conv1', 3),
    ('up_sa_layer3/post-conv', 'set_upconv3.conv2', 1),
    ('fa_layer4/conv_', 'fp.conv', 2),
]

mapping = {}
for _tf_prefix, _torch_seq, _n_convs in _CONV_STACKS:
    for _i in range(_n_convs):
        mapping.update(_conv_bn_names(_tf_prefix + str(_i), _torch_seq, 3 * _i))
# Classifier head: fc1 is conv + batch norm, fc2 a plain conv without batch norm.
mapping.update(_conv_bn_names('fc1', 'classifier', 0))
mapping['classifier.3.weight'] = 'fc2/weights'
mapping['classifier.3.bias'] = 'fc2/biases'

# Convert the TF checkpoint (tf_vars, looked up through `mapping`:
# PyTorch key -> TF name) into a FlowNet3D state dict and save it.

def _move_xyz_channels_first(weight, xyz_start):
    """In place, move input channels [xyz_start, xyz_start + 3) of a conv
    weight (out, in, kh, kw) to the front, shifting channels [0, xyz_start)
    right by 3; channels after xyz_start + 3 are untouched.

    The TF model concatenates the 3 relative-xyz channels after the grouped
    features, whereas Group in this notebook puts them first.
    """
    xyz = weight[:, xyz_start:xyz_start + 3].clone()
    weight[:, 3:xyz_start + 3] = weight[:, :xyz_start].clone()
    weight[:, :3] = xyz

# TF -> PyTorch axis order for each parameter rank.
_TF_TO_TORCH_PERMUTATION = {
    4: (3, 2, 0, 1),  # conv2d kernel (kh, kw, in, out) -> (out, in, kh, kw)
    3: (2, 1, 0),     # conv1d kernel (k, in, out)      -> (out, in, k)
    1: (0,),          # biases and batch-norm vectors
}

# Conv weights whose TF input layout places the xyz channels elsewhere,
# with the input-channel index where those 3 channels start in TF.
_XYZ_CHANNEL_START = {
    'flow_embedding.conv.0.weight': 256,  # TF: [feat2 (128), feat1 (128), xyz (3)]
    'set_upconv1.conv2.0.weight': 512,    # TF: [pooled feat (512), xyz (3), skip (256)]
    'set_upconv2.conv1.0.weight': 256,    # TF: [feat (256), xyz (3)]
    'set_upconv3.conv1.0.weight': 256,    # TF: [feat (256), xyz (3)]
}

state_dict = FlowNet3D().state_dict()
for key, param in state_dict.items():
    order = _TF_TO_TORCH_PERMUTATION.get(param.dim())
    if order is None:
        # 0-dim buffers (e.g. BatchNorm num_batches_tracked) have no TF counterpart.
        continue
    # state_dict() tensors share storage with the model, so copy_ writes in place.
    param.copy_(torch.from_numpy(tf_vars[mapping[key]]).permute(*order))
    if key in _XYZ_CHANNEL_START:
        _move_xyz_channels_first(param, _XYZ_CHANNEL_START[key])

# models/ may not exist on a fresh checkout; torch.save would then fail.
os.makedirs('models', exist_ok=True)
torch.save(state_dict, 'models/net_tf.pth')

In [7]:
%%time

# data
test_set = SceneflowDataset(train=False)
test_loader = DataLoader(test_set,
                         batch_size=16,
                         num_workers=4,
                         pin_memory=True,
                         drop_last=True)

print('test set:', len(test_set), 'samples /', len(test_loader), 'mini-batches')

# model
net = FlowNet3D().cuda()
net.load_state_dict(torch.load('models/net_tf.pth'))
net.eval()

# statistics
loss_sum = 0
epe_sum = 0
acc050_sum = 0
acc010_sum = 0

with torch.no_grad():
    
    # for each mini-batch
    for points1, points2, features1, features2, flow, mask1 in test_loader:
            
        # to GPU
        points1 = points1.cuda(non_blocking=True)
        points2 = points2.cuda(non_blocking=True)
        features1 = features1.cuda(non_blocking=True)
        features2 = features2.cuda(non_blocking=True)
        flow = flow.cuda(non_blocking=True)
        mask1 = mask1.cuda(non_blocking=True)
    
        pred_flow_sum = torch.zeros(16, 3, 2048).cuda(non_blocking=True)
        
        # resample 10 times
        for i in range(10):
            
            perm = torch.randperm(points1.shape[2])
            points1_perm = points1[:, :, perm]
            points2_perm = points2[:, :, perm]
            features1_perm = features1[:, :, perm]
            features2_perm = features2[:, :, perm]

            # forward
            pred_flow = net(points1_perm, points2_perm, features1_perm, features2_perm)
            pred_flow_sum[:, :, perm] += pred_flow
            pred_flow_sum=pred_flow_sum
        
        # statistics
        pred_flow_sum /= 10
        loss = criterion(pred_flow_sum, flow, mask1)
        loss_sum += loss.item()
        epe, acc050, acc010 = error(pred_flow_sum, flow, mask1)
        epe_sum += epe
        acc050_sum += acc050
        acc010_sum += acc010
        
print('mean loss:', loss_sum/len(test_loader))
print('mean epe:', epe_sum/len(test_loader))
print('mean acc050:', acc050_sum/len(test_loader))
print('mean acc010:', acc010_sum/len(test_loader))
    
print('---')

test set: 2007 samples / 125 mini-batches
mean loss: 0.024514434352517128
mean epe: 0.14857300339009882
mean acc050: 0.28733477243061456
mean acc010: 0.6192634478202645
---
CPU times: user 1min 6s, sys: 25.7 s, total: 1min 32s
Wall time: 1min 33s


In [None]:
# parameters
# -- data / optimization
BATCH_SIZE = 16
NUM_POINTS = 2048
NUM_EPOCHS = 150
# -- learning rate (ClippedStepLR): multiplied by GAMMA_LR every STEP_SIZE_LR
#    scheduler steps, never below MIN_LR
INIT_LR = 1e-3
MIN_LR = 1e-5
STEP_SIZE_LR = 10
GAMMA_LR = 0.7
# -- batch-norm momentum (PyTorch convention): multiplied by GAMMA_BN_MOMENTUM
#    every STEP_SIZE_BN_MOMENTUM epochs, never below MIN_BN_MOMENTUM
INIT_BN_MOMENTUM = 0.5
MIN_BN_MOMENTUM = 0.01
STEP_SIZE_BN_MOMENTUM = 10
GAMMA_BN_MOMENTUM = 0.5

# data
# Each split gets its own multiprocessing.Manager dict as sample cache; a
# Manager dict is a proxy that DataLoader worker processes can share, unlike
# a plain dict, which would be copied into every worker.
def _make_split_loader(is_train, cache):
    """Return (dataset, loader) for one split; only the training split is shuffled."""
    dataset = SceneflowDataset(npoints=NUM_POINTS, train=is_train, cache=cache)
    loader = DataLoader(dataset,
                        batch_size=BATCH_SIZE,
                        num_workers=4,
                        shuffle=is_train,
                        pin_memory=True,
                        drop_last=True)
    return dataset, loader

train_manager = Manager()
train_cache = train_manager.dict()
train_dataset, train_loader = _make_split_loader(True, train_cache)
print('train:', len(train_dataset), '/', len(train_loader))

test_manager = Manager()
test_cache = test_manager.dict()
test_dataset, test_loader = _make_split_loader(False, test_cache)
print('test:', len(test_dataset), '/', len(test_loader))

# net
def init_weights(m):
    """Initialize one module in place; meant for ``net.apply(init_weights)``.

    - Conv1d / Conv2d: Xavier-uniform weight, zero bias (when present).
    - BatchNorm1d / BatchNorm2d: weight (gamma) 1, bias (beta) 0 (when affine).
    Any other module type is left untouched.
    """
    if isinstance(m, (nn.Conv1d, nn.Conv2d)):
        nn.init.xavier_uniform_(m.weight)
        # Convs built with bias=False have m.bias = None.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # affine=False batch norms have no weight/bias parameters.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

net = FlowNet3D().cuda()
net.apply(init_weights)
print('# parameters: ', parameter_count(net))

# optimizer
optimizer = optim.Adam(net.parameters(), lr=INIT_LR)

# learning rate scheduler
lr_scheduler = ClippedStepLR(optimizer, STEP_SIZE_LR, MIN_LR, GAMMA_LR)

# batch norm momentum scheduler
def update_bn_momentum(epoch):
    """Apply the step-decayed batch-norm momentum for ``epoch`` to every BN layer of ``net``.

    The value is ``INIT_BN_MOMENTUM * GAMMA_BN_MOMENTUM ** (epoch // STEP_SIZE_BN_MOMENTUM)``,
    floored at ``MIN_BN_MOMENTUM``. Only BatchNorm1d and BatchNorm2d layers are updated.

    Args:
        epoch: zero-based epoch index.
    """
    decayed = INIT_BN_MOMENTUM * GAMMA_BN_MOMENTUM ** (epoch // STEP_SIZE_BN_MOMENTUM)
    momentum = max(decayed, MIN_BN_MOMENTUM)
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d)
    for layer in filter(lambda mod: isinstance(mod, bn_types), net.modules()):
        layer.momentum = momentum

def _run_epoch(model, loader, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean loss per mini-batch, elapsed seconds)``.

    With an ``optimizer`` the model is trained: train mode, gradients enabled,
    one optimizer step per mini-batch. Without one it is only evaluated: eval
    mode under ``torch.no_grad``. The loss is ``criterion`` (defined elsewhere
    in the notebook).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    running_loss = 0.0
    torch.cuda.synchronize()  # keep previously queued GPU work out of the timing
    start_time = time.time()
    with torch.set_grad_enabled(training):
        for batch in loader:
            # each batch unpacks to (points1, points2, features1, features2, flow, mask1)
            points1, points2, features1, features2, flow, mask1 = (
                t.cuda(non_blocking=True) for t in batch)
            if training:
                optimizer.zero_grad()
            pred_flow = model(points1, points2, features1, features2)
            loss = criterion(pred_flow, flow, mask1)
            if training:
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
    torch.cuda.synchronize()  # wait for the last kernels before stopping the clock
    return running_loss / len(loader), time.time() - start_time


# per-epoch mean losses, plotted after training
losses_train = []
losses_test = []

for epoch in range(NUM_EPOCHS):
    # batch-norm momentum follows its own per-epoch schedule
    update_bn_momentum(epoch)

    for phase, loader, history, opt in (('train', train_loader, losses_train, optimizer),
                                        ('test', test_loader, losses_test, None)):
        mean_loss, duration = _run_epoch(net, loader, opt)
        history.append(mean_loss)
        # per-iteration time uses the length of the loader that was actually run
        # (the test phase previously divided by len(train_loader))
        print('Epoch {} ({}) -- loss: {:.6f} -- duration (epoch/iteration): {:.4f} min/{:.4f} sec'.format(
            epoch, phase, mean_loss, duration / 60.0, duration / len(loader)))

    # update learning rate
    lr_scheduler.step()

    print('---')

# Save before plotting, so a plotting failure cannot cost the trained weights.
# The model is moved to the CPU first, so the saved state_dict holds CPU tensors.
net = net.cpu()
os.makedirs('models', exist_ok=True)
torch.save(net.state_dict(), 'models/net.pth')

fig, ax = plt.subplots()
ax.plot(losses_train, label='train')
ax.plot(losses_test, label='test')
ax.set(xlabel='epoch', ylabel='mean loss', title='Training and test loss')
ax.legend()
plt.show()

In [8]:
%%time

# data
test_set = SceneflowDataset(train=False)
test_loader = DataLoader(test_set,
                         batch_size=16,
                         num_workers=4,
                         pin_memory=True,
                         drop_last=True)

print('test set:', len(test_set), 'samples /', len(test_loader), 'mini-batches')

# model
net = FlowNet3D().cuda()
net.load_state_dict(torch.load('models/net.pth'))
net.eval()

# statistics
loss_sum = 0
epe_sum = 0
acc050_sum = 0
acc010_sum = 0

with torch.no_grad():
    
    # for each mini-batch
    for points1, points2, features1, features2, flow, mask1 in test_loader:
            
        # to GPU
        points1 = points1.cuda(non_blocking=True)
        points2 = points2.cuda(non_blocking=True)
        features1 = features1.cuda(non_blocking=True)
        features2 = features2.cuda(non_blocking=True)
        flow = flow.cuda(non_blocking=True)
        mask1 = mask1.cuda(non_blocking=True)
    
        pred_flow_sum = torch.zeros(16, 3, 2048).cuda(non_blocking=True)
        
        # resample 10 times
        for i in range(10):
            
            perm = torch.randperm(points1.shape[2])
            points1_perm = points1[:, :, perm]
            points2_perm = points2[:, :, perm]
            features1_perm = features1[:, :, perm]
            features2_perm = features2[:, :, perm]

            # forward
            pred_flow = net(points1_perm, points2_perm, features1_perm, features2_perm)
            pred_flow_sum[:, :, perm] += pred_flow
        
        # statistics
        pred_flow_sum /= 10
        loss = criterion(pred_flow_sum, flow, mask1)
        loss_sum += loss.item()
        epe, acc050, acc010 = error(pred_flow_sum, flow, mask1)
        epe_sum += epe
        acc050_sum += acc050
        acc010_sum += acc010
        
print('mean loss:', loss_sum/len(test_loader))
print('mean epe:', epe_sum/len(test_loader))
print('mean acc050:', acc050_sum/len(test_loader))
print('mean acc010:', acc010_sum/len(test_loader))
    
print('---')

test set: 2007 samples / 125 mini-batches
mean loss: 0.022257591910660266
mean epe: 0.1463848694865201
mean acc050: 0.28563897619579187
mean acc010: 0.622624611406866
---
CPU times: user 1min 4s, sys: 25 s, total: 1min 29s
Wall time: 1min 30s
