In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from data import ModelNet40,download,load_data
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from geomloss import SamplesLoss
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np
from matplotlib import pyplot as plt
import pickle
import warnings
import math
import time
from datetime import datetime
from collections import OrderedDict
from scipy.optimize import linear_sum_assignment
from torch import jit
%matplotlib inline
import os
from torch.utils.data import Dataset, TensorDataset
import importlib

In [2]:
train_losses = []
test_results_exp = []
test_accs = []

In [3]:
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 2 every 30 epochs"""
    lrt = lr * (0.5 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lrt

In [4]:
device = torch.device("cuda:0")

In [5]:
data = ModelNet40(num_points=1024, partition='train', gaussian_noise=False,
                       unseen=False, factor=4)
data_test = ModelNet40(num_points=1024, partition='test', gaussian_noise=False,
                       unseen=False, factor=4)
train_loader = DataLoader(data, batch_size=128, shuffle=True, drop_last=True)
test_loader = DataLoader(data_test, batch_size=128, shuffle=True, drop_last=True)
rl_loader = DataLoader(data_test, batch_size=1, shuffle=True, drop_last=True)

  f = h5py.File(h5_name)


In [6]:
def cal_rt_loss(ra,ta,rb,tb):
    r_loss = F.mse_loss(ra,rb).item()
    t_loss = F.mse_loss(ta,tb).item()
    return r_loss, t_loss

In [7]:
class PointNet(nn.Module):
    def __init__(self, emb_dims=512):
        super(PointNet, self).__init__()
        self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=1, bias=False)
        self.conv5 = nn.Conv1d(128, emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(emb_dims)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        return x

In [8]:
class CPNet(nn.Module):
    def __init__(self, emb_dims=512):
        super(CPNet, self).__init__()
        self.base_net = PointNet()
        self.loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

        
    def forward(self, x1, x2, R, T):
        
        trans = torch.matmul(R, x1) + T.unsqueeze(2)
        emb1 = self.base_net(trans)
        emb2 = self.base_net(x2)
        
        dist = self.loss(emb1,emb2)

        return dist

In [140]:
class CPEvalNet(nn.Module):
    def __init__(self, base_net,batch_size, emb_dims=512):
        super(CPEvalNet, self).__init__()
        self.base_net = PointNet()
        self.base_net.load_state_dict(base_net.state_dict())
        for param in self.base_net.parameters():
            param.requires_grad = False
        self.batch_size = batch_size
        self.rotation = nn.Parameter(torch.Tensor(batch_size,3,3))
        self.translation = nn.Parameter(torch.Tensor(batch_size,3))
        nn.init.kaiming_uniform_(self.rotation, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.translation, a=math.sqrt(5))
        self.loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

        
    def forward(self, x1, x2):
        
        trans = torch.matmul(self.rotation, x1) + self.translation.unsqueeze(2)
        emb1 = self.base_net(trans)
        emb2 = self.base_net(x2)
        
        dist = self.loss(emb1,emb2)

        return dist
    
    def cal_emb(self,x1,x2):
        trans = torch.matmul(self.rotation, x1) + self.translation.unsqueeze(2)
        emb1 = self.base_net(trans)
        emb2 = self.base_net(x2)
        return emb1,emb2

In [141]:
base_model = torch.load("no_pert_base.mdl").to(device)

In [142]:
def dummy_loss(output, target):
    loss = torch.mean(output)
    return loss

In [143]:
def rt_learning(model, device, src, target, rotation_ab, translation_ab, optimizer,criterion,epoch,n_batch = 40):
    r_losses = []
    t_losses = []
    for _ in range(n_batch):
        model.train()
        t1 = datetime.now()
        src = src.to(device)
        target = target.to(device)
        rotation_ab = rotation_ab.to(device)
        translation_ab = translation_ab.to(device)
        optimizer.zero_grad()
        output_train = model(src,target)
        loss = criterion(output_train, target)
        loss.backward()
        optimizer.step()
        r_loss,t_loss = cal_rt_loss(model.rotation, model.translation,rotation_ab, translation_ab)
        r_losses.append(r_loss)
        t_losses.append(t_loss)
    return r_losses, t_losses

In [144]:
r_losses = []
t_losses = []

In [147]:
for src_rl, target_rl, rotation_ab_rl, translation_ab_rl, rotation_ba_rl, translation_ba_rl,_,_ in rl_loader:
    model_eval = CPEvalNet(base_model.base_net,len(src_rl)).to(device)
    
    criterion = dummy_loss
    lr = 1e+1
    momentum=0.9
    weight_decay = 5e-4
    optimizer = torch.optim.SGD(model_eval.parameters(), lr,
                                    momentum=momentum,
                                    weight_decay=weight_decay)
    best=0
    
    for epoch in range(1,250):
        #adjust_learning_rate(optimizer, epoch)
        t1 = datetime.now()
        r_batch_loss, t_batch_loss = rt_learning(model_eval, device, src_rl, target_rl, rotation_ab_rl, translation_ab_rl, 
                                         optimizer,criterion,epoch,n_batch=1)
        if epoch%40==0:

            print('Train Epoch: {} \tRotation_Loss: {:.6f}\tTranslation_Loss: {:.6f}\tTime: {:.2f}'.format(
                epoch, r_batch_loss[-1],t_batch_loss[-1],(datetime.now()-t1).total_seconds()))
    print("*********************************************")

Train Epoch: 1 	Rotation_Loss: 5519.810059	Translation_Loss: 0.079069	Time: 0.03
Train Epoch: 2 	Rotation_Loss: 19877.384766	Translation_Loss: 0.078813	Time: 0.02
Train Epoch: 3 	Rotation_Loss: 40124.328125	Translation_Loss: 0.078457	Time: 0.02
Train Epoch: 4 	Rotation_Loss: 63859.472656	Translation_Loss: 0.078031	Time: 0.02
Train Epoch: 5 	Rotation_Loss: 89161.140625	Translation_Loss: 0.077562	Time: 0.02
Train Epoch: 6 	Rotation_Loss: 114523.625000	Translation_Loss: 0.077080	Time: 0.02
Train Epoch: 7 	Rotation_Loss: 138798.859375	Translation_Loss: 0.076610	Time: 0.02
Train Epoch: 8 	Rotation_Loss: 161142.812500	Translation_Loss: 0.076175	Time: 0.02
Train Epoch: 9 	Rotation_Loss: 180967.390625	Translation_Loss: 0.075795	Time: 0.02
Train Epoch: 10 	Rotation_Loss: 197897.062500	Translation_Loss: 0.075489	Time: 0.02
Train Epoch: 11 	Rotation_Loss: 211730.328125	Translation_Loss: 0.075269	Time: 0.02
Train Epoch: 12 	Rotation_Loss: 222405.750000	Translation_Loss: 0.075147	Time: 0.02
Train E

Train Epoch: 103 	Rotation_Loss: 137.350830	Translation_Loss: 0.102236	Time: 0.02
Train Epoch: 104 	Rotation_Loss: 124.652901	Translation_Loss: 0.102207	Time: 0.02
Train Epoch: 105 	Rotation_Loss: 113.344513	Translation_Loss: 0.102180	Time: 0.02
Train Epoch: 106 	Rotation_Loss: 103.710365	Translation_Loss: 0.102155	Time: 0.02
Train Epoch: 107 	Rotation_Loss: 95.264542	Translation_Loss: 0.102132	Time: 0.02
Train Epoch: 108 	Rotation_Loss: 87.795464	Translation_Loss: 0.102111	Time: 0.02
Train Epoch: 109 	Rotation_Loss: 81.553673	Translation_Loss: 0.102093	Time: 0.02
Train Epoch: 110 	Rotation_Loss: 76.335136	Translation_Loss: 0.102075	Time: 0.02
Train Epoch: 111 	Rotation_Loss: 71.553619	Translation_Loss: 0.102060	Time: 0.02
Train Epoch: 112 	Rotation_Loss: 67.211494	Translation_Loss: 0.102046	Time: 0.02
Train Epoch: 113 	Rotation_Loss: 63.591675	Translation_Loss: 0.102034	Time: 0.02
Train Epoch: 114 	Rotation_Loss: 61.760025	Translation_Loss: 0.102023	Time: 0.02
Train Epoch: 115 	Rotati

Train Epoch: 213 	Rotation_Loss: 36.210506	Translation_Loss: 0.102081	Time: 0.02
Train Epoch: 214 	Rotation_Loss: 39.536873	Translation_Loss: 0.102080	Time: 0.02
Train Epoch: 215 	Rotation_Loss: 42.900242	Translation_Loss: 0.102080	Time: 0.02
Train Epoch: 216 	Rotation_Loss: 45.840103	Translation_Loss: 0.102080	Time: 0.02
Train Epoch: 217 	Rotation_Loss: 48.164558	Translation_Loss: 0.102080	Time: 0.02
Train Epoch: 218 	Rotation_Loss: 50.027943	Translation_Loss: 0.102080	Time: 0.02
Train Epoch: 219 	Rotation_Loss: 51.456352	Translation_Loss: 0.102080	Time: 0.02
Train Epoch: 220 	Rotation_Loss: 52.450825	Translation_Loss: 0.102079	Time: 0.02
Train Epoch: 221 	Rotation_Loss: 52.995197	Translation_Loss: 0.102079	Time: 0.02
Train Epoch: 222 	Rotation_Loss: 52.995380	Translation_Loss: 0.102079	Time: 0.02
Train Epoch: 223 	Rotation_Loss: 52.560200	Translation_Loss: 0.102079	Time: 0.02
Train Epoch: 224 	Rotation_Loss: 51.822041	Translation_Loss: 0.102079	Time: 0.02
Train Epoch: 225 	Rotation_L

Train Epoch: 70 	Rotation_Loss: 15815.524414	Translation_Loss: 0.102659	Time: 0.02
Train Epoch: 71 	Rotation_Loss: 18348.115234	Translation_Loss: 0.102631	Time: 0.02
Train Epoch: 72 	Rotation_Loss: 20599.269531	Translation_Loss: 0.102601	Time: 0.02
Train Epoch: 73 	Rotation_Loss: 22525.765625	Translation_Loss: 0.102570	Time: 0.02
Train Epoch: 74 	Rotation_Loss: 24104.119141	Translation_Loss: 0.102538	Time: 0.02
Train Epoch: 75 	Rotation_Loss: 25326.699219	Translation_Loss: 0.102506	Time: 0.02
Train Epoch: 76 	Rotation_Loss: 26198.371094	Translation_Loss: 0.102473	Time: 0.02
Train Epoch: 77 	Rotation_Loss: 26733.603516	Translation_Loss: 0.102439	Time: 0.02
Train Epoch: 78 	Rotation_Loss: 26953.996094	Translation_Loss: 0.102406	Time: 0.02
Train Epoch: 79 	Rotation_Loss: 26886.181641	Translation_Loss: 0.102373	Time: 0.02
Train Epoch: 80 	Rotation_Loss: 26560.058594	Translation_Loss: 0.102340	Time: 0.02
Train Epoch: 81 	Rotation_Loss: 26007.316406	Translation_Loss: 0.102307	Time: 0.02
Trai

Train Epoch: 180 	Rotation_Loss: 1074.671875	Translation_Loss: 0.101771	Time: 0.02
Train Epoch: 181 	Rotation_Loss: 2054.594238	Translation_Loss: 0.101770	Time: 0.02
Train Epoch: 182 	Rotation_Loss: 3197.641846	Translation_Loss: 0.101769	Time: 0.02
Train Epoch: 183 	Rotation_Loss: 4415.334473	Translation_Loss: 0.101768	Time: 0.02
Train Epoch: 184 	Rotation_Loss: 5640.160156	Translation_Loss: 0.101767	Time: 0.02
Train Epoch: 185 	Rotation_Loss: 6818.159668	Translation_Loss: 0.101767	Time: 0.02
Train Epoch: 186 	Rotation_Loss: 7908.227539	Translation_Loss: 0.101766	Time: 0.02
Train Epoch: 187 	Rotation_Loss: 8881.011719	Translation_Loss: 0.101766	Time: 0.02
Train Epoch: 188 	Rotation_Loss: 9717.230469	Translation_Loss: 0.101765	Time: 0.02
Train Epoch: 189 	Rotation_Loss: 10405.958008	Translation_Loss: 0.101765	Time: 0.02
Train Epoch: 190 	Rotation_Loss: 10943.080078	Translation_Loss: 0.101764	Time: 0.02
Train Epoch: 191 	Rotation_Loss: 11329.935547	Translation_Loss: 0.101764	Time: 0.02
T

Train Epoch: 37 	Rotation_Loss: 28087.378906	Translation_Loss: 0.123153	Time: 0.02
Train Epoch: 38 	Rotation_Loss: 24510.291016	Translation_Loss: 0.124634	Time: 0.02
Train Epoch: 39 	Rotation_Loss: 21270.873047	Translation_Loss: 0.126082	Time: 0.02
Train Epoch: 40 	Rotation_Loss: 18354.183594	Translation_Loss: 0.127486	Time: 0.02
Train Epoch: 41 	Rotation_Loss: 15743.200195	Translation_Loss: 0.128836	Time: 0.02
Train Epoch: 42 	Rotation_Loss: 13419.388672	Translation_Loss: 0.130125	Time: 0.02
Train Epoch: 43 	Rotation_Loss: 11363.314453	Translation_Loss: 0.131348	Time: 0.02
Train Epoch: 44 	Rotation_Loss: 9555.061523	Translation_Loss: 0.132499	Time: 0.02
Train Epoch: 45 	Rotation_Loss: 7974.543457	Translation_Loss: 0.133576	Time: 0.02
Train Epoch: 46 	Rotation_Loss: 6601.802734	Translation_Loss: 0.134576	Time: 0.02
Train Epoch: 47 	Rotation_Loss: 5417.520508	Translation_Loss: 0.135498	Time: 0.02
Train Epoch: 48 	Rotation_Loss: 4403.165039	Translation_Loss: 0.136343	Time: 0.02
Train Epo

KeyboardInterrupt: 

In [37]:
model_eval = CPEvalNet(base_model.base_net,len(src_rl)).to(device)

In [47]:
src = src_rl.to(device)
target = target_rl.to(device)

In [48]:
emb1,emb2 = model_eval.cal_emb(src,target)

In [49]:
emb1

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       grad_fn=<ReluBackward0>)

In [55]:
emb2

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       grad_fn=<ReluBackward0>)

In [36]:
target

tensor([[[-0.0518, -0.1208, -0.6257,  ..., -0.0921,  0.6426,  0.3896],
         [ 0.2355,  0.4974,  0.3281,  ...,  0.2768, -0.1622,  0.3229],
         [ 0.6433, -0.4532, -0.0385,  ..., -0.1564,  0.9694,  0.5029]]],
       device='cuda:0')

In [41]:
model_eval.translation

Parameter containing:
tensor([[-0.3587,  0.5572,  0.3573]], device='cuda:0', requires_grad=True)

In [133]:
class Tnet(nn.Module):
    def __init__(self):
        super(Tnet, self).__init__()
        self.base_net = nn.Linear(4,5)
        for param in self.base_net.parameters():
            param.requires_grad = False
        self.linear = nn.Linear(5,10)

        
    def forward(self, x1):
        
        y = self.base_net(x1)
        y = self.linear(y)

        return y

In [134]:
mdl_fix = Tnet()

In [136]:
a = torch.randn(128,4)
b = torch.randint(9,(128,))

In [137]:
loss_fix = nn.CrossEntropyLoss()
lr = 1e-0
momentum=0.9
weight_decay = 5e-4
optimizer = torch.optim.SGD(mdl_fix.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

In [138]:
optimizer.zero_grad()
o = mdl_fix(a)
loss = loss_fix(o,b)
loss.backward()
optimizer.step()

In [135]:
mdl_fix.base_net.weight

Parameter containing:
tensor([[-0.3622,  0.4383, -0.0323, -0.3923],
        [-0.4348, -0.3971,  0.2977, -0.1876],
        [ 0.0916, -0.5000, -0.0444,  0.3572],
        [ 0.3948, -0.4296,  0.4550,  0.1864],
        [ 0.4938,  0.4447,  0.4731, -0.3320]])

In [112]:
mdl_fix.linear.weight

Parameter containing:
tensor([[-0.2740, -0.3083, -0.1948, -0.2773,  0.2340],
        [ 0.3289,  0.2262, -0.1302,  0.2608,  0.3197],
        [-0.4101, -0.3186, -0.3246, -0.0478,  0.1414],
        [ 0.0234, -0.2619, -0.0990, -0.0118, -0.1745],
        [ 0.2538,  0.1531, -0.3491,  0.0481, -0.3088],
        [-0.0593, -0.3752, -0.1289,  0.2848,  0.3547],
        [ 0.4445,  0.1241, -0.2501, -0.4075, -0.2104],
        [-0.0064,  0.3962, -0.0112, -0.4004, -0.0814],
        [-0.4317, -0.3341, -0.3595, -0.1399,  0.1505],
        [ 0.0969, -0.0780, -0.1750, -0.0828, -0.2956]], requires_grad=True)

In [139]:
mdl_fix.base_net.weight

Parameter containing:
tensor([[-0.3622,  0.4383, -0.0323, -0.3923],
        [-0.4348, -0.3971,  0.2977, -0.1876],
        [ 0.0916, -0.5000, -0.0444,  0.3572],
        [ 0.3948, -0.4296,  0.4550,  0.1864],
        [ 0.4938,  0.4447,  0.4731, -0.3320]])

In [104]:
mdl_fix.base_net.weight.requires_grad

True