In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from data_with_pert import ModelNet40_pert,download,load_data
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from geomloss import SamplesLoss
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np
from matplotlib import pyplot as plt
import pickle
import warnings
import math
import time
from datetime import datetime
from collections import OrderedDict
from scipy.optimize import linear_sum_assignment
from torch import jit
%matplotlib inline
import os
from torch.utils.data import Dataset, TensorDataset
import importlib

In [2]:
train_losses = []
test_results_exp = []
test_accs = []

In [3]:
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 2 every 30 epochs"""
    lrt = lr * (0.5 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lrt

In [4]:
def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
    return torch.index_select(a, dim, order_index)

In [5]:
device = torch.device("cuda:7")

In [6]:
M=10
batch_size= 64
data = ModelNet40_pert(num_points=1000, partition='train', gaussian_noise=False,
                       unseen=False, factor=4,n_pert=M)
data_test = ModelNet40_pert(num_points=1000, partition='test', gaussian_noise=False,
                       unseen=False, factor=4,n_pert=M)
train_loader = DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(data_test, batch_size=batch_size, shuffle=True, drop_last=True)
rl_loader = DataLoader(data_test, batch_size=1, shuffle=True, drop_last=True)

  f = h5py.File(h5_name)


In [7]:
def train_eval_icp(model,device, train_loader,test_loader, optimizer,criterion, epoch, eval_mode='on'):
    global best
    model.train()
    final_cal = nn.Sigmoid()
    train_error_logs = []
    t1 = datetime.now()
    test_loss_log = []
    test_acc_log = []
    train_corrects = 0
    
    for batch_idx, (src, target, r_ab, t_ab, _,_,_,_,r_ab_pert,t_ab_pert) in enumerate(train_loader):
        src = src.to(device)
        target = target.to(device)
        r_ab = r_ab.to(device)
        t_ab = t_ab.to(device)
        r_ab_pert = r_ab_pert.to(device)
        t_ab_pert = t_ab_pert.to(device)
        optimizer.zero_grad()
        dist, dist_pert = model(src,target,r_ab,t_ab,r_ab_pert,t_ab_pert)
        loss = criterion(dist,dist_pert, target)
        loss.backward()
        optimizer.step()
        #pred_train = output_train.argmax(dim=1, keepdim=True) # get the index of the max log-probability
        #train_corrects += pred_train.eq(target.view_as(pred_train)).sum().item()
        train_error_logs.append(loss.item())
        
        if (batch_idx+0) % 20 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTime: {:.2f}'.format(
                epoch, (batch_idx+0)* len(target), len(train_loader.dataset),
                100. * (batch_idx+0) / len(train_loader), loss,(datetime.now()-t1).total_seconds()))

        
    print('Train Epoch: {} Accuracy: {}/{} ({:.2f}%)\n'.format(
                epoch, train_corrects, len(train_loader.dataset),
                100. * train_corrects / len(train_loader.dataset)))
    '''
    model.eval()
    test_loss = 0
    correct_test = 0
    with torch.no_grad():
        for src, target, rotation_ab, translation_ab, rotation_ba, translation_ba, euler_ab, euler_ba in test_loader:
            src = src.to(device)
            target = target.to(device)
            rotation_ab = rotation_ab.to(device)
            translation_ab = translation_ab.to(device)
            rotation_ba = rotation_ba.to(device)
            translation_ba = translation_ba.to(device)
            output_test = model(src,target,rotation_ab,translation_ab)
            test_loss += criterion(output_test, target).item()*len(src)  # sum up batch loss
    test_loss /= len(test_loader.dataset)
    test_loss_log.append(test_loss)
    print('Test set: Average loss: {:.8f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct_test, len(test_loader.dataset),
            100. * correct_test / len(test_loader.dataset)))
    '''
    return train_error_logs,test_loss_log,test_acc_log

In [8]:
def cal_rt_loss(ra,ta,rb,tb):
    r_loss = F.mse_loss(ra,rb).item()
    t_loss = F.mse_loss(ta,tb).item()
    return r_loss, t_loss

In [9]:
class PointNet(nn.Module):
    def __init__(self, emb_dims=512):
        super(PointNet, self).__init__()
        self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=1, bias=False)
        self.conv5 = nn.Conv1d(128, emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(emb_dims)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        return x

In [10]:
class CPNet(nn.Module):
    def __init__(self, emb_dims=512,split = 10):
        super(CPNet, self).__init__()
        self.split = split
        self.base_net = PointNet()
        self.loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

        
    def forward(self, x1, x2, R, T, R_pert, T_pert):
        
        trans = torch.matmul(R, x1) + T.unsqueeze(2)
        emb1 = self.base_net(trans)
        emb2 = self.base_net(x2)
        
        dist = self.loss(emb1,emb2)
        
        pert_out = torch.matmul(R_pert.permute(1,0,2,3), x1) + T_pert.permute(1,0,2).unsqueeze(3)
        pert_out = pert_out.reshape(-1, pert_out.shape[-2],pert_out.shape[-1])
        pert_emb1 = self.base_net(pert_out)
        pert_emb2 = tile(emb2,0,self.split)
        pert_dist = self.loss(pert_emb1,pert_emb2)/self.split
        pert_dist = pert_dist.reshape(self.split, emb2.shape[0]).sum(0)
        

        return dist,pert_dist

In [11]:
class CPEvalNet(nn.Module):
    def __init__(self, base_net, emb_dims=512):
        super(CPEvalNet, self).__init__()
        self.base_net = PointNet()
        self.base_net.load_state_dict(base_net.state_dict())
        for param in self.base_net.parameters():
            param.require_grad = False
        self.rotation = nn.Parameter(torch.Tensor(1,3,3))
        self.translation = nn.Parameter(torch.Tensor(1,3))
        nn.init.kaiming_uniform_(self.rotation, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.translation, a=math.sqrt(5))
        self.loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

        
    def forward(self, x1, x2):
        
        trans = torch.matmul(self.rotation, x1) + self.translation.unsqueeze(2)
        emb1 = self.base_net(trans)
        emb2 = self.base_net(x2)
        
        dist = self.loss(emb1,emb2)

        return dist

In [12]:
def combine_loss(dist, pert_dist, target, alpha = 1, beta = 1, eps = 0.005, eta = 0.005):
    loss = torch.mean(dist)
    
    diff = dist-pert_dist
    dummy = torch.zeros_like(diff)
    pert_loss = F.hinge_embedding_loss(diff,dummy,margin = eps)
    
    return loss+alpha*pert_loss

In [13]:
lr = 1e-3
momentum = 0.9
weight_decay = 5e-4

In [14]:
model = CPNet().to(device)

In [15]:
criterion = combine_loss
optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
best=0

In [None]:
train_time= []
train_loss = []
train_losses.append(train_loss)
tests = []
test_acc = []
test_results_exp.append(tests)
test_accs.append(test_acc)
ratio = 0
for epoch in range(1,600):
    adjust_learning_rate(optimizer, epoch)
    t1 = datetime.now()
    train_error,test_error,test_acc_this = train_eval_icp(model,device, train_loader, test_loader, optimizer,criterion,epoch,'on')
    train_loss.extend(train_error)
    tests.extend(test_error)
    test_acc.extend(test_acc_this)
    train_time.append((datetime.now()-t1).total_seconds())
    print((datetime.now()-t1).total_seconds())

Train Epoch: 1 Accuracy: 0/9840 (0.00%)

227.928147


In [20]:
torch.cuda.empty_cache()