In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from data import ModelNet40,download,load_data
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from geomloss import SamplesLoss
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np
from matplotlib import pyplot as plt
import pickle
import warnings
import math
import time
from datetime import datetime
from collections import OrderedDict
from scipy.optimize import linear_sum_assignment
from torch import jit
%matplotlib inline
import os
from torch.utils.data import Dataset, TensorDataset

In [2]:
# Cross-run histories: each training run appends its own per-run list to each of these.
train_losses, test_results_exp, test_accs = [], [], []

In [3]:
def adjust_learning_rate(optimizer, epoch, base_lr=None, gamma=0.5, step_size=30):
    """Apply a step-decay learning-rate schedule to every param group of ``optimizer``.

    The new rate is ``base_lr * gamma ** (epoch // step_size)``. With the defaults this
    halves the rate every 30 epochs.

    Args:
        optimizer: torch optimizer whose ``param_groups`` are updated in place.
        epoch: current epoch number.
        base_lr: initial learning rate. Defaults to the notebook-level ``lr`` (read at
            call time, so the cell defining ``lr`` must have been run first).
        gamma: multiplicative decay applied once every ``step_size`` epochs.
        step_size: number of epochs between decays.

    Returns:
        The learning rate that was written into the optimizer.
    """
    if base_lr is None:
        # Fallback kept for existing callers that rely on the global ``lr``.
        base_lr = lr
    new_lr = base_lr * (gamma ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr

In [4]:
# Use the third GPU (index 2) as originally configured, but fall back to CPU when it does
# not exist, so the notebook still runs on machines with fewer GPUs.
device = torch.device(
    "cuda:2" if torch.cuda.is_available() and torch.cuda.device_count() > 2 else "cpu"
)

In [27]:
# ModelNet40 registration pairs, 2048 points per cloud, for the train and test partitions.
data = ModelNet40(num_points=2048, partition='train', gaussian_noise=False,
                  unseen=False, factor=4)
data_test = ModelNet40(num_points=2048, partition='test', gaussian_noise=False,
                       unseen=False, factor=4)
train_loader = DataLoader(data, batch_size=128, shuffle=True, drop_last=True)
# Evaluation should visit every test sample exactly once in a fixed order: no shuffling,
# and keep the final partial batch instead of silently dropping it.
test_loader = DataLoader(data_test, batch_size=128, shuffle=False, drop_last=False)

In [32]:
def _train_one_epoch(model, device, train_loader, optimizer, criterion, epoch, log_every=20):
    """Run one optimisation pass over ``train_loader`` and return the per-batch loss values."""
    model.train()
    losses = []
    t_start = datetime.now()
    # Explicitly (re-)enable autograd for the training pass so it cannot inherit a grad mode
    # that was left disabled elsewhere in the kernel. NOTE(review): the recorded
    # "does not require grad" RuntimeError may have had such a cause — unconfirmed.
    with torch.enable_grad():
        # Loader yields 8-tuples; only the first four entries are used by the model.
        for batch_idx, (src, target, rotation_ab, translation_ab, *_unused) in enumerate(train_loader):
            src = src.to(device)
            target = target.to(device)
            rotation_ab = rotation_ab.to(device)
            translation_ab = translation_ab.to(device)
            optimizer.zero_grad()
            output = model(src, target, rotation_ab, translation_ab)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

            if batch_idx % log_every == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTime: {:.2f}'.format(
                    epoch, batch_idx * len(target), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), losses[-1],
                    (datetime.now() - t_start).total_seconds()))
    if losses:
        print('Train Epoch: {} Average loss: {:.6f}\n'.format(epoch, sum(losses) / len(losses)))
    return losses


def _evaluate(model, device, loader, criterion):
    """Return the per-sample mean of ``criterion`` over ``loader`` (``nan`` if it yields nothing)."""
    model.eval()
    total_loss = 0.0
    n_seen = 0
    with torch.no_grad():
        for src, target, rotation_ab, translation_ab, *_unused in loader:
            src = src.to(device)
            target = target.to(device)
            rotation_ab = rotation_ab.to(device)
            translation_ab = translation_ab.to(device)
            output = model(src, target, rotation_ab, translation_ab)
            batch_size = len(src)
            total_loss += criterion(output, target).item() * batch_size  # undo the batch mean
            n_seen += batch_size
    # Normalise by the samples actually evaluated, not len(loader.dataset): with
    # drop_last=True the final partial batch is never seen.
    return total_loss / n_seen if n_seen else float('nan')


def train_eval_icp(model, device, train_loader, test_loader, optimizer, criterion, epoch, eval_mode='on'):
    """Train ``model`` for one epoch, then optionally evaluate it on ``test_loader``.

    Both loaders must yield 8-tuples ``(src, target, rotation_ab, translation_ab,
    rotation_ba, translation_ba, euler_ab, euler_ba)``. Only the first four entries are
    used, and the model is called as ``model(src, target, rotation_ab, translation_ab)``.

    Args:
        model: module whose output is passed to ``criterion``.
        device: device the batch tensors are moved to.
        train_loader: loader for the optimisation pass.
        test_loader: loader for the evaluation pass.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: callable ``criterion(output, target)`` returning a scalar tensor.
        epoch: epoch number, used only for logging.
        eval_mode: ``'on'`` (default) runs the evaluation pass; any other value skips it.

    Returns:
        ``(train_error_logs, test_loss_log, test_acc_log)``:
        - per-batch training losses;
        - a one-element list holding the mean per-sample test loss (empty when evaluation
          is skipped);
        - an always-empty accuracy list, kept for interface compatibility because no
          accuracy metric is computed.
    """
    train_error_logs = _train_one_epoch(model, device, train_loader, optimizer, criterion, epoch)
    test_loss_log = []
    test_acc_log = []
    if eval_mode == 'on':
        test_loss = _evaluate(model, device, test_loader, criterion)
        test_loss_log.append(test_loss)
        print('Test set: Average loss: {:.8f}\n'.format(test_loss))
    return train_error_logs, test_loss_log, test_acc_log

In [7]:
class PointNet(nn.Module):
    """Per-point feature extractor: five 1x1 ``Conv1d`` layers, each followed by BatchNorm and ReLU.

    The input has 3 channels in dim 1 (e.g. ``(B, 3, N)``). Kernel size 1 keeps the point
    dimension, so the output is ``(B, emb_dims, N)``.
    """

    def __init__(self, emb_dims=512):
        super(PointNet, self).__init__()
        widths = [3, 64, 64, 64, 128, emb_dims]
        # Register conv1..conv5 first, then bn1..bn5, with the same attribute names and
        # registration order as before, so state_dict() keys and parameter order are unchanged.
        for idx in range(1, 6):
            setattr(self, 'conv%d' % idx,
                    nn.Conv1d(widths[idx - 1], widths[idx], kernel_size=1, bias=False))
        for idx in range(1, 6):
            setattr(self, 'bn%d' % idx, nn.BatchNorm1d(widths[idx]))

    def forward(self, x):
        stages = [(getattr(self, 'conv%d' % idx), getattr(self, 'bn%d' % idx)) for idx in range(1, 6)]
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        return x

In [8]:
class CPNet(nn.Module):
    """Shared-weight PointNet encoder scored with a Sinkhorn distance.

    The source cloud is moved by the given (R, T). Both clouds are then embedded with the
    same PointNet, and the geomloss Sinkhorn loss (p=2, blur=0.05) between the two
    embeddings is returned.

    Args:
        emb_dims: number of output channels of the PointNet encoder.
    """

    def __init__(self, emb_dims=512):
        super(CPNet, self).__init__()
        # Pass emb_dims through; it was previously ignored (PointNet() always used 512).
        self.base_net = PointNet(emb_dims=emb_dims)
        self.loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

    def forward(self, x1, x2, R, T):
        """Return the Sinkhorn distance between the embeddings of ``R @ x1 + T`` and ``x2``.

        Args:
            x1, x2: point clouds with 3 channels in dim 1 (e.g. ``(B, 3, N)``), as required
                by PointNet's first ``Conv1d(3, ...)``.
            R: matrix applied to ``x1`` via ``matmul`` — presumably ``(B, 3, 3)``; TODO confirm.
            T: translation of shape ``(B, 3)``, broadcast over points via ``unsqueeze(2)``.

        Returns:
            The output of geomloss ``SamplesLoss`` for the two embeddings.
        """
        trans = torch.matmul(R, x1) + T.unsqueeze(2)
        emb1 = self.base_net(trans)
        emb2 = self.base_net(x2)
        # NOTE(review): emb1/emb2 are (B, emb_dims, N). geomloss reads 3-D inputs as
        # (B, num_points, dim), so this compares emb_dims "points" of dimension N. If the
        # intent is per-point features, transpose(1, 2) first — TODO confirm.
        dist = self.loss(emb1, emb2)
        return dist

In [9]:
# Scratch check: allocating a tensor without an initializer returns whatever was in memory
# (note the garbage values in the output below), which is why CPEvalNet explicitly
# re-initialises its rotation/translation parameters.
torch.empty(1, 3, 3)

tensor([[[-2.0961e+37,  3.0921e-41,  9.3664e-24],
         [ 4.5804e-41,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00]]])

In [10]:
class CPEvalNet(nn.Module):
    """Fit one learnable (rotation, translation) pair under a frozen PointNet embedding.

    ``rotation`` is a free 3x3 parameter (not constrained to be a proper rotation) and
    ``translation`` a 3-vector. Both are shared across the batch. The forward pass returns
    the Sinkhorn distance between the embeddings of the transformed ``x1`` and ``x2``.

    Args:
        base_net: trained encoder whose ``state_dict`` matches ``PointNet(emb_dims)`` —
            presumably a trained ``CPNet``'s ``.base_net``; confirm at the call site.
        emb_dims: embedding width of ``base_net``.
    """

    def __init__(self, base_net, emb_dims=512):
        super(CPEvalNet, self).__init__()
        # Independent copy of the trained encoder, sized to match base_net.
        self.base_net = PointNet(emb_dims=emb_dims)
        self.base_net.load_state_dict(base_net.state_dict())
        # Freeze the encoder so only rotation/translation are optimised. The previous loop
        # iterated the module itself (TypeError: a PointNet is not iterable) and set the
        # misspelled attribute `require_grad`, which froze nothing.
        for param in self.base_net.parameters():
            param.requires_grad_(False)
        # torch.Tensor(...) allocates uninitialised memory; kaiming_uniform_ fills it.
        self.rotation = nn.Parameter(torch.Tensor(1, 3, 3))
        self.translation = nn.Parameter(torch.Tensor(1, 3))
        nn.init.kaiming_uniform_(self.rotation, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.translation, a=math.sqrt(5))
        self.loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

    def forward(self, x1, x2):
        """Return the Sinkhorn distance between embeddings of ``rotation @ x1 + translation`` and ``x2``.

        ``x1``/``x2`` need 3 channels in dim 1 (e.g. ``(B, 3, N)``). The (1, 3, 3) rotation
        and (1, 3, 1) translation broadcast over the batch.
        """
        trans = torch.matmul(self.rotation, x1) + self.translation.unsqueeze(2)
        emb1 = self.base_net(trans)
        emb2 = self.base_net(x2)
        dist = self.loss(emb1, emb2)
        return dist

In [11]:
def dummy_loss(output, target):
    """Return the mean of ``output``.

    ``target`` is accepted only to match the ``criterion(output, target)`` call signature
    and is ignored. The model this notebook trains (CPNet) already returns the distance
    to minimise.
    """
    return output.mean()

In [28]:
# SGD hyper-parameters. `lr` is also the base rate read by adjust_learning_rate.
lr, momentum, weight_decay = 1e-3, 0.9, 5e-4

In [33]:
# Fresh registration network (default 512-dim embedding), placed on the selected device.
model = CPNet(emb_dims=512).to(device)

In [34]:
criterion = dummy_loss  # the model output is already the distance; the criterion just averages it
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
)
best = 0

In [35]:
# Fresh per-run history lists, registered in the cross-run collections.
train_time = []
train_loss = []
train_losses.append(train_loss)
tests = []
test_acc = []
test_results_exp.append(tests)
test_accs.append(test_acc)
ratio = 0  # NOTE(review): not read by any visible cell
# NOTE(review): range(1, 600) runs epochs 1..599 — confirm whether 600 epochs were intended.
for epoch in range(1, 600):
    adjust_learning_rate(optimizer, epoch)
    t1 = datetime.now()
    train_error, test_error, test_acc_this = train_eval_icp(
        model, device, train_loader, test_loader, optimizer, criterion, epoch, 'on')
    train_loss.extend(train_error)
    tests.extend(test_error)
    test_acc.extend(test_acc_this)
    train_time.append((datetime.now() - t1).total_seconds())
    # Print the duration that was just logged instead of re-measuring it, so the printed
    # and stored epoch times agree.
    print(train_time[-1])

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [29]:
# Pull just the first training batch into globals for interactive inspection
# (no-op when the loader is empty).
first_batch = next(iter(enumerate(train_loader)), None)
if first_batch is not None:
    batch_idx, (src, target, rotation_ab, translation_ab,
                rotation_ba, translation_ba, euler_ab, euler_ba) = first_batch
del first_batch

In [34]:
# Sanity check: one 3-D translation vector per sample in the batch.
translation_ab.size()

torch.Size([128, 3])