In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from data import ModelNet40,download,load_data
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from geomloss import SamplesLoss
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np
from matplotlib import pyplot as plt
import pickle
import warnings
import math
import time
from datetime import datetime
from collections import OrderedDict
from scipy.optimize import linear_sum_assignment
from torch import jit
%matplotlib inline
import os
from torch.utils.data import Dataset, TensorDataset
import importlib

In [2]:
train_losses = []
test_results_exp = []
test_accs = []

In [3]:
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 2 every 30 epochs"""
    lrt = lr * (0.5 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lrt

In [4]:
device = torch.device("cuda:6")

In [41]:
data = ModelNet40(num_points=2048, partition='train', gaussian_noise=False,
                       unseen=False, factor=4)
data_test = ModelNet40(num_points=2048, partition='test', gaussian_noise=False,
                       unseen=True, factor=4)
train_loader = DataLoader(data, batch_size=128, shuffle=True, drop_last=True)
test_loader = DataLoader(data_test, batch_size=128, shuffle=True, drop_last=True)
rl_loader = DataLoader(data_test, batch_size=1, shuffle=True, drop_last=True)

In [6]:
def cal_rt_loss(ra,ta,rb,tb):
    r_loss = F.mse_loss(ra,rb).item()
    t_loss = F.mse_loss(ta,tb).item()
    return r_loss, t_loss

In [7]:
class PointNet(nn.Module):
    def __init__(self, emb_dims=512):
        super(PointNet, self).__init__()
        self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=1, bias=False)
        self.conv5 = nn.Conv1d(128, emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(emb_dims)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        return x

In [8]:
class CPNet(nn.Module):
    def __init__(self, emb_dims=512,split = 10):
        super(CPNet, self).__init__()
        self.split = split
        self.base_net = PointNet()
        self.loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

        
    def forward(self, x1, x2, R, T, R_pert, T_pert):
        
        trans = torch.matmul(R, x1) + T.unsqueeze(2)
        emb1 = self.base_net(trans)
        emb2 = self.base_net(x2)
        
        dist = self.loss(emb1,emb2)
        
        pert_out = torch.matmul(R_pert.permute(1,0,2,3), x1) + T_pert.permute(1,0,2).unsqueeze(3)
        pert_out = pert_out.reshape(-1, pert_out.shape[-2],pert_out.shape[-1])
        pert_emb1 = self.base_net(pert_out)
        pert_emb2 = tile(emb2,0,self.split)
        pert_dist = self.loss(pert_emb1,pert_emb2)/self.split
        pert_dist = pert_dist.reshape(self.split, emb2.shape[0]).sum(0)
        

        return dist,pert_dist

In [9]:
class CPEvalNet(nn.Module):
    def __init__(self, base_net, emb_dims=512):
        super(CPEvalNet, self).__init__()
        self.base_net = PointNet()
        self.base_net.load_state_dict(base_net.state_dict())
        for param in self.base_net.parameters():
            param.require_grad = False
        self.rotation = nn.Parameter(torch.Tensor(1,3,3))
        self.translation = nn.Parameter(torch.Tensor(1,3))
        nn.init.kaiming_uniform_(self.rotation, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.translation, a=math.sqrt(5))
        self.loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

        
    def forward(self, x1, x2):
        
        trans = torch.matmul(self.rotation, x1) + self.translation.unsqueeze(2)
        emb1 = self.base_net(trans)
        emb2 = self.base_net(x2)
        
        dist = self.loss(emb1,emb2)

        return dist

In [13]:
base_model = torch.load("pert_model.mdl").to(device)

In [14]:
def dummy_loss(output, target):
    loss = torch.mean(output)
    return loss

In [15]:
def rt_learning(model, device, src, target, rotation_ab, translation_ab, optimizer,criterion,epoch,n_batch = 40):
    r_losses = []
    t_losses = []
    for _ in range(n_batch):
        model.train()
        t1 = datetime.now()
        src = src.to(device)
        target = target.to(device)
        rotation_ab = rotation_ab.to(device)
        translation_ab = translation_ab.to(device)
        optimizer.zero_grad()
        output_train = model(src,target)
        loss = criterion(output_train, target)
        loss.backward()
        optimizer.step()
        r_loss,t_loss = cal_rt_loss(model.rotation, model.translation,rotation_ab, translation_ab)
        r_losses.append(r_loss)
        t_losses.append(t_loss)
    return r_losses, t_losses

In [16]:
r_losses = []
t_losses = []

In [44]:
model_eval = CPEvalNet(base_model.base_net).to(device)

In [45]:
criterion = dummy_loss
lr = 1e-1
momentum=0.9
weight_decay = 5e-4
optimizer = torch.optim.SGD(model_eval.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
best=0

In [42]:
for src_rl, target_rl, rotation_ab_rl, translation_ab_rl, rotation_ba_rl, translation_ba_rl,_,_ in rl_loader:
    break

In [46]:
train_time= []
r_loss = []
t_loss = []
r_losses.append(r_loss)
t_losses.append(t_loss)
ratio = 0
for epoch in range(1,600):
    #adjust_learning_rate(optimizer, epoch)
    t1 = datetime.now()
    r_batch_loss, t_batch_loss = rt_learning(model_eval, device, src_rl, target_rl, rotation_ab_rl, translation_ab_rl, 
                                     optimizer,criterion,epoch,n_batch=40)
    r_loss.extend(r_batch_loss)
    t_loss.extend(t_batch_loss)
    train_time.append((datetime.now()-t1).total_seconds())

    print('Train Epoch: {} \tRotation_Loss: {:.6f}\tTranslation_Loss: {:.6f}\tTime: {:.2f}'.format(
        epoch, r_batch_loss[-1],t_batch_loss[-1],(datetime.now()-t1).total_seconds()))

Train Epoch: 1 	Rotation_Loss: 143.610825	Translation_Loss: 0.272337	Time: 1.02
Train Epoch: 2 	Rotation_Loss: 142.305649	Translation_Loss: 0.268373	Time: 0.82
Train Epoch: 3 	Rotation_Loss: 136.947861	Translation_Loss: 0.264535	Time: 0.78
Train Epoch: 4 	Rotation_Loss: 131.736603	Translation_Loss: 0.260830	Time: 0.76
Train Epoch: 5 	Rotation_Loss: 126.726303	Translation_Loss: 0.257256	Time: 0.77
Train Epoch: 6 	Rotation_Loss: 121.910072	Translation_Loss: 0.253806	Time: 0.77
Train Epoch: 7 	Rotation_Loss: 117.280373	Translation_Loss: 0.250476	Time: 0.82
Train Epoch: 8 	Rotation_Loss: 112.829865	Translation_Loss: 0.247263	Time: 0.91
Train Epoch: 9 	Rotation_Loss: 108.551552	Translation_Loss: 0.244160	Time: 0.81
Train Epoch: 10 	Rotation_Loss: 104.438744	Translation_Loss: 0.241165	Time: 0.79
Train Epoch: 11 	Rotation_Loss: 100.484970	Translation_Loss: 0.238274	Time: 0.75
Train Epoch: 12 	Rotation_Loss: 96.684029	Translation_Loss: 0.235482	Time: 0.77
Train Epoch: 13 	Rotation_Loss: 93.029

Train Epoch: 104 	Rotation_Loss: 3.684473	Translation_Loss: 0.156429	Time: 0.76
Train Epoch: 105 	Rotation_Loss: 3.574750	Translation_Loss: 0.156270	Time: 0.74
Train Epoch: 106 	Rotation_Loss: 3.468908	Translation_Loss: 0.156114	Time: 0.77
Train Epoch: 107 	Rotation_Loss: 3.366798	Translation_Loss: 0.155963	Time: 0.74
Train Epoch: 108 	Rotation_Loss: 3.268287	Translation_Loss: 0.155815	Time: 0.76
Train Epoch: 109 	Rotation_Loss: 3.173238	Translation_Loss: 0.155671	Time: 0.76
Train Epoch: 110 	Rotation_Loss: 3.081525	Translation_Loss: 0.155531	Time: 0.76
Train Epoch: 111 	Rotation_Loss: 2.993024	Translation_Loss: 0.155395	Time: 0.75
Train Epoch: 112 	Rotation_Loss: 2.907615	Translation_Loss: 0.155261	Time: 0.76
Train Epoch: 113 	Rotation_Loss: 2.825186	Translation_Loss: 0.155132	Time: 0.76
Train Epoch: 114 	Rotation_Loss: 2.745627	Translation_Loss: 0.155005	Time: 0.76
Train Epoch: 115 	Rotation_Loss: 2.668831	Translation_Loss: 0.154882	Time: 0.76
Train Epoch: 116 	Rotation_Loss: 2.59469

ValueError: Maximum allowed size exceeded