In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from data import ModelNet40,download,load_data
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from geomloss import SamplesLoss
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np
from matplotlib import pyplot as plt
import pickle
import warnings
import math
import time
from datetime import datetime
from collections import OrderedDict
from scipy.optimize import linear_sum_assignment
from torch import jit
%matplotlib inline
import os
from torch.utils.data import Dataset, TensorDataset

In [14]:
train_losses = []
test_results_exp = []
test_accs = []

In [16]:
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 2 every 30 epochs"""
    lrt = lr * (0.5 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lrt

In [2]:
device = torch.device("cuda:2")

In [36]:
data = ModelNet40(num_points=2048, partition='train', gaussian_noise=False,
                       unseen=False, factor=4)
data_test = ModelNet40(num_points=2048, partition='test', gaussian_noise=False,
                       unseen=False, factor=4)
train_loader = DataLoader(data, batch_size=128, shuffle=True, drop_last=True)
test_loader = DataLoader(data_test, batch_size=1, shuffle=True, drop_last=True)

  f = h5py.File(h5_name)


In [39]:
def train_eval_icp(model,device, train_loader,test_loader, optimizer,criterion, epoch, eval_mode='on'):
    global best
    model.train()
    final_cal = nn.Sigmoid()
    train_error_logs = []
    t1 = datetime.now()
    test_loss_log = []
    test_acc_log = []
    train_corrects = 0
    
    for batch_idx, (src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba) in enumerate(train_loader):
        src = src.to(device)
        target = target.to(device)
        rotation_ab = rotation_ab.to(device)
        translation_ab = translation_ab.to(device)
        rotation_ba = rotation_ba.to(device)
        translation_ba = translation_ba.to(device)
        optimizer.zero_grad()
        output_train = model(src,target,rotation_ab,translation_ab)
        loss = criterion(output_train, target)
        loss.backward()
        optimizer.step()
        #pred_train = output_train.argmax(dim=1, keepdim=True) # get the index of the max log-probability
        #train_corrects += pred_train.eq(target.view_as(pred_train)).sum().item()
        train_error_logs.append(loss.item())
        
        if (batch_idx+0) % 20 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTime: {:.2f}'.format(
                epoch, (batch_idx+0)* len(target), len(train_loader.dataset),
                100. * (batch_idx+0) / len(train_loader), loss,(datetime.now()-t1).total_seconds()))

        
    print('Train Epoch: {} Accuracy: {}/{} ({:.2f}%)\n'.format(
                epoch, train_corrects, len(train_loader.dataset),
                100. * train_corrects / len(train_loader.dataset)))
    
    model.eval()
    test_loss = 0
    
    with torch.no_grad():
        for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba in enumerate(test_loader):
            src = src.to(device)
            target = target.to(device)
            rotation_ab = rotation_ab.to(device)
            translation_ab = translation_ab.to(device)
            rotation_ba = rotation_ba.to(device)
            translation_ba = translation_ba.to(device)
            output_test = model(src,target,rotation_ab,translation_ab)
            test_loss += criterion(output_test, target_test).item()  # sum up batch loss
    test_loss /= len(test_loader.dataset)
    test_loss_log.append(test_loss)
    print('Test set: Average loss: {:.8f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct_test, len(test_loader.dataset),
            100. * correct_test / len(test_loader.dataset)))
    return train_error_logs,test_loss_log,test_acc_log

In [5]:
class PointNet(nn.Module):
    def __init__(self, emb_dims=512):
        super(PointNet, self).__init__()
        self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=1, bias=False)
        self.conv5 = nn.Conv1d(128, emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(emb_dims)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        return x

In [6]:
class CPNet(nn.Module):
    def __init__(self, emb_dims=512):
        super(CPNet, self).__init__()
        self.base_net = PointNet()
        self.loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

        
    def forward(self, x1, x2, R, T):
        
        trans = torch.matmul(R, x1) + T.unsqueeze(2)
        emb1 = self.base_net(trans)
        emb2 = self.base_net(x2)
        
        dist = self.loss(emb1,emb2)

        return dist

In [32]:
torch.Tensor(1,3,3)

tensor([[[ 1.3452e-43,  0.0000e+00,  8.9683e-44],
         [ 0.0000e+00, -1.1594e+09,  3.0635e-41],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00]]])

In [35]:
class CPEvalNet(nn.Module):
    def __init__(self, base_net, emb_dims=512):
        super(CPEvalNet, self).__init__()
        self.base_net = PointNet()
        self.base_net.load_state_dict(base_net.state_dict())
        for param in self.base_net:
            param.require_grad = False
        self.rotation = nn.Parameter(torch.Tensor(1,3,3))
        self.translation = nn.Parameter(torch.Tensor(1,3))
        nn.init.kaiming_uniform_(self.rotation, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.translation, a=math.sqrt(5))
        self.loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

        
    def forward(self, x1, x2):
        
        trans = torch.matmul(self.rotation, x1) + self.translation.unsqueeze(2)
        emb1 = self.base_net(trans)
        emb2 = self.base_net(x2)
        
        dist = self.loss(emb1,emb2)

        return dist

In [7]:
def dummy_loss(output, target):
    loss = torch.mean(output)
    return loss

In [25]:
lr = 1e-3
momentum = 0.9
weight_decay = 5e-4

In [40]:
model = CPNet().to(device)

In [41]:
criterion = dummy_loss
optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
best=0

In [42]:
train_time= []
train_loss = []
train_losses.append(train_loss)
tests = []
test_acc = []
test_results_exp.append(tests)
test_accs.append(test_acc)
ratio = 0
for epoch in range(1,600):
    adjust_learning_rate(optimizer, epoch)
    t1 = datetime.now()
    train_error,test_error,test_acc_this = train_icp(model,device, train_loader, optimizer,criterion,epoch,'on')
    train_loss.extend(train_error)
    tests.extend(test_error)
    test_acc.extend(test_acc_this)
    train_time.append((datetime.now()-t1).total_seconds())
    print((datetime.now()-t1).total_seconds())

Train Epoch: 1 Accuracy: 0/9840 (0.00%)

39.347571
Train Epoch: 2 Accuracy: 0/9840 (0.00%)

38.072406
Train Epoch: 3 Accuracy: 0/9840 (0.00%)

37.519726
Train Epoch: 4 Accuracy: 0/9840 (0.00%)

37.115867
Train Epoch: 5 Accuracy: 0/9840 (0.00%)

36.978782
Train Epoch: 6 Accuracy: 0/9840 (0.00%)

36.597896
Train Epoch: 7 Accuracy: 0/9840 (0.00%)

36.506719
Train Epoch: 8 Accuracy: 0/9840 (0.00%)

36.325491
Train Epoch: 9 Accuracy: 0/9840 (0.00%)

36.198596
Train Epoch: 10 Accuracy: 0/9840 (0.00%)

36.202035
Train Epoch: 11 Accuracy: 0/9840 (0.00%)

36.230876
Train Epoch: 12 Accuracy: 0/9840 (0.00%)

36.21919
Train Epoch: 13 Accuracy: 0/9840 (0.00%)

36.097763
Train Epoch: 14 Accuracy: 0/9840 (0.00%)

36.109575
Train Epoch: 15 Accuracy: 0/9840 (0.00%)

36.307178
Train Epoch: 16 Accuracy: 0/9840 (0.00%)

35.962115
Train Epoch: 17 Accuracy: 0/9840 (0.00%)

35.818317
Train Epoch: 18 Accuracy: 0/9840 (0.00%)

35.806736
Train Epoch: 19 Accuracy: 0/9840 (0.00%)

35.809771
Train Epoch: 20 Accura

Train Epoch: 29 Accuracy: 0/9840 (0.00%)

35.7524
Train Epoch: 30 Accuracy: 0/9840 (0.00%)

35.808145
Train Epoch: 31 Accuracy: 0/9840 (0.00%)

35.730218
Train Epoch: 32 Accuracy: 0/9840 (0.00%)

35.714265
Train Epoch: 33 Accuracy: 0/9840 (0.00%)

35.698686
Train Epoch: 34 Accuracy: 0/9840 (0.00%)

35.695941
Train Epoch: 35 Accuracy: 0/9840 (0.00%)

35.846071
Train Epoch: 36 Accuracy: 0/9840 (0.00%)

35.705099
Train Epoch: 37 Accuracy: 0/9840 (0.00%)

35.70653
Train Epoch: 38 Accuracy: 0/9840 (0.00%)

35.52013
Train Epoch: 39 Accuracy: 0/9840 (0.00%)

35.341958
Train Epoch: 40 Accuracy: 0/9840 (0.00%)

35.524058
Train Epoch: 41 Accuracy: 0/9840 (0.00%)

35.728774
Train Epoch: 42 Accuracy: 0/9840 (0.00%)

35.582293
Train Epoch: 43 Accuracy: 0/9840 (0.00%)

35.642189
Train Epoch: 44 Accuracy: 0/9840 (0.00%)

35.58883
Train Epoch: 45 Accuracy: 0/9840 (0.00%)

35.690308
Train Epoch: 46 Accuracy: 0/9840 (0.00%)

35.705753
Train Epoch: 47 Accuracy: 0/9840 (0.00%)

35.545765
Train Epoch: 48 A

Train Epoch: 57 Accuracy: 0/9840 (0.00%)

35.517275
Train Epoch: 58 Accuracy: 0/9840 (0.00%)

35.527514
Train Epoch: 59 Accuracy: 0/9840 (0.00%)

35.47355
Train Epoch: 60 Accuracy: 0/9840 (0.00%)

35.479712
Train Epoch: 61 Accuracy: 0/9840 (0.00%)

35.613811
Train Epoch: 62 Accuracy: 0/9840 (0.00%)

35.498817
Train Epoch: 63 Accuracy: 0/9840 (0.00%)

35.534423
Train Epoch: 64 Accuracy: 0/9840 (0.00%)

35.541741
Train Epoch: 65 Accuracy: 0/9840 (0.00%)

35.613126
Train Epoch: 66 Accuracy: 0/9840 (0.00%)

35.451705
Train Epoch: 67 Accuracy: 0/9840 (0.00%)

35.488816
Train Epoch: 68 Accuracy: 0/9840 (0.00%)

35.555874
Train Epoch: 69 Accuracy: 0/9840 (0.00%)

35.570939
Train Epoch: 70 Accuracy: 0/9840 (0.00%)

35.494414
Train Epoch: 71 Accuracy: 0/9840 (0.00%)

35.56013
Train Epoch: 72 Accuracy: 0/9840 (0.00%)

35.568737
Train Epoch: 73 Accuracy: 0/9840 (0.00%)

35.355983
Train Epoch: 74 Accuracy: 0/9840 (0.00%)

35.338218
Train Epoch: 75 Accuracy: 0/9840 (0.00%)

35.447361
Train Epoch: 7

Train Epoch: 86 Accuracy: 0/9840 (0.00%)

35.552937
Train Epoch: 87 Accuracy: 0/9840 (0.00%)

35.473067
Train Epoch: 88 Accuracy: 0/9840 (0.00%)

35.47008
Train Epoch: 89 Accuracy: 0/9840 (0.00%)

35.514947
Train Epoch: 90 Accuracy: 0/9840 (0.00%)

35.410664
Train Epoch: 91 Accuracy: 0/9840 (0.00%)

35.459141
Train Epoch: 92 Accuracy: 0/9840 (0.00%)

35.453251
Train Epoch: 93 Accuracy: 0/9840 (0.00%)

35.552693
Train Epoch: 94 Accuracy: 0/9840 (0.00%)

35.585656
Train Epoch: 95 Accuracy: 0/9840 (0.00%)

35.403199
Train Epoch: 96 Accuracy: 0/9840 (0.00%)

35.486714
Train Epoch: 97 Accuracy: 0/9840 (0.00%)

35.585235
Train Epoch: 98 Accuracy: 0/9840 (0.00%)

35.531487
Train Epoch: 99 Accuracy: 0/9840 (0.00%)

35.436858
Train Epoch: 100 Accuracy: 0/9840 (0.00%)

35.486492
Train Epoch: 101 Accuracy: 0/9840 (0.00%)

35.404701
Train Epoch: 102 Accuracy: 0/9840 (0.00%)

35.408989
Train Epoch: 103 Accuracy: 0/9840 (0.00%)

35.48754
Train Epoch: 104 Accuracy: 0/9840 (0.00%)

35.494833
Train Epo

Train Epoch: 114 Accuracy: 0/9840 (0.00%)

35.396639
Train Epoch: 115 Accuracy: 0/9840 (0.00%)

35.388188
Train Epoch: 116 Accuracy: 0/9840 (0.00%)

35.497035
Train Epoch: 117 Accuracy: 0/9840 (0.00%)

35.464299
Train Epoch: 118 Accuracy: 0/9840 (0.00%)

35.423378
Train Epoch: 119 Accuracy: 0/9840 (0.00%)

35.463594
Train Epoch: 120 Accuracy: 0/9840 (0.00%)

35.419398
Train Epoch: 121 Accuracy: 0/9840 (0.00%)

35.459756
Train Epoch: 122 Accuracy: 0/9840 (0.00%)

35.516553
Train Epoch: 123 Accuracy: 0/9840 (0.00%)

35.724804
Train Epoch: 124 Accuracy: 0/9840 (0.00%)

35.462614
Train Epoch: 125 Accuracy: 0/9840 (0.00%)

35.419053
Train Epoch: 126 Accuracy: 0/9840 (0.00%)

35.488481
Train Epoch: 127 Accuracy: 0/9840 (0.00%)

35.389201
Train Epoch: 128 Accuracy: 0/9840 (0.00%)

35.412535
Train Epoch: 129 Accuracy: 0/9840 (0.00%)

35.454899
Train Epoch: 130 Accuracy: 0/9840 (0.00%)

35.39836
Train Epoch: 131 Accuracy: 0/9840 (0.00%)

35.429829
Train Epoch: 132 Accuracy: 0/9840 (0.00%)

35.4

Train Epoch: 142 Accuracy: 0/9840 (0.00%)

35.472426
Train Epoch: 143 Accuracy: 0/9840 (0.00%)

35.36946
Train Epoch: 144 Accuracy: 0/9840 (0.00%)

35.332554
Train Epoch: 145 Accuracy: 0/9840 (0.00%)

35.335777
Train Epoch: 146 Accuracy: 0/9840 (0.00%)

35.483594
Train Epoch: 147 Accuracy: 0/9840 (0.00%)

35.503085
Train Epoch: 148 Accuracy: 0/9840 (0.00%)

35.455375
Train Epoch: 149 Accuracy: 0/9840 (0.00%)

35.390075
Train Epoch: 150 Accuracy: 0/9840 (0.00%)

35.504133
Train Epoch: 151 Accuracy: 0/9840 (0.00%)

35.361838
Train Epoch: 152 Accuracy: 0/9840 (0.00%)

35.428978
Train Epoch: 153 Accuracy: 0/9840 (0.00%)

35.382729
Train Epoch: 154 Accuracy: 0/9840 (0.00%)

35.399612
Train Epoch: 155 Accuracy: 0/9840 (0.00%)

35.445037
Train Epoch: 156 Accuracy: 0/9840 (0.00%)

35.444106
Train Epoch: 157 Accuracy: 0/9840 (0.00%)

35.458645
Train Epoch: 158 Accuracy: 0/9840 (0.00%)

35.472643
Train Epoch: 159 Accuracy: 0/9840 (0.00%)

35.36906
Train Epoch: 160 Accuracy: 0/9840 (0.00%)

35.45

KeyboardInterrupt: 

In [29]:
for batch_idx, (src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba) in enumerate(train_loader):
    break

In [34]:
translation_ab.shape

torch.Size([128, 3])