In [None]:
from model.model import *
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pickle
from dataset import ShapeNetDataset
import numpy as np
import math
%matplotlib inline
from IPython.display import clear_output
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:

def sample_data(dataset, data_type, batch_size):
    """Draw a batch from ``dataset`` and return it without the third field.

    Thin wrapper over ``dataset.get_data(data_type, batch_size)``, which must
    return exactly three values (point clouds, labels, and a third item named
    ``fns`` in this notebook, presumably file names).

    Returns:
        tuple: ``(point_clouds, labels)`` exactly as returned by the dataset.
    """
    point_clouds, part_labels, _unused = dataset.get_data(data_type, batch_size)
    return point_clouds, part_labels
    

def compute_mIoU(preds, target, n_parts=None):
    """Mean intersection-over-union across part labels for one point cloud.

    Args:
        preds: per-point predicted part labels (integer indices, e.g.
            ``logits.argmax(dim=1)``), as a torch tensor or array-like.
        target: per-point ground-truth part labels, same number of points as ``preds``.
        n_parts: number of part classes to average over. If None it is inferred as
            ``max(target) + 1``, since labels are 0-based class indices.

    Returns:
        float: mean IoU over parts ``0 .. n_parts - 1``. A part that appears in
        neither ``preds`` nor ``target`` contributes an IoU of 1.

    Raises:
        ValueError: if ``preds`` and ``target`` differ in size or ``n_parts < 1``.
    """
    def _as_flat_array(x):
        # accept torch tensors (any device, with or without grad) as well as array-likes
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        return np.asarray(x).reshape(-1)

    preds = _as_flat_array(preds)
    target = _as_flat_array(target)
    if preds.shape != target.shape:
        raise ValueError(
            f"preds and target must have the same number of points, got {preds.size} and {target.size}")
    if n_parts is None:
        # Labels are 0-based (they are CrossEntropyLoss targets), so max label + 1 parts
        # exist. Using max(target) alone skipped the last part and divided by zero
        # when only part 0 was present.
        n_parts = int(target.max()) + 1
    if n_parts < 1:
        raise ValueError(f"n_parts must be >= 1, got {n_parts}")

    ious = []
    for part in range(n_parts):
        in_pred = preds == part
        in_target = target == part
        union = np.count_nonzero(in_pred | in_target)
        # a part absent from both prediction and ground truth counts as perfectly segmented
        ious.append(1.0 if union == 0 else np.count_nonzero(in_pred & in_target) / union)
    return float(np.mean(ious))


In [None]:
##Hyperparameters
import random

SEED = 0               # seeds python/numpy/torch RNGs below so runs are reproducible
N_CLASSES = 16         # NOTE(review): not referenced anywhere else in this notebook
EPOCHS = 1000          # optimisation steps (one sampled batch each); alt setting: 2000
BATCH_SIZE = 32        # point clouds sampled per step (fed to the net one at a time)
INIT_LR = 0.001        # Adam learning rate
MOMENTUM = 0.9         # NOTE(review): unused — the Adam optimizer is built without it
LR_STEP = 20           # StepLR step_size, counted in scheduler.step() calls
SCHEDULER_GAMMA = 0.5  # LR multiplier applied every LR_STEP scheduler steps
VAL_EVERY = 1          # run validation every VAL_EVERY epochs
REG_WEIGHT = 0.001     # weight of the transform-matrix regularization term
criterion = nn.CrossEntropyLoss()

# Seed every RNG before the model is built and batches are sampled.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load one ShapeNet category, build the part-segmentation network with that
# category's part count, and attach Adam plus a step-decay LR schedule.
class_ID = 0
ShapeNetData = ShapeNetDataset('datasets/ShapeNet/', class_ID)
N_of_parts = ShapeNetData.get_N_parts(class_ID)

net = PointNetDenseClassification(N_of_parts)
net = net.to(device)

optimizer = optim.Adam(net.parameters(), lr=INIT_LR)
scheduler = optim.lr_scheduler.StepLR(optimizer, LR_STEP, gamma=SCHEDULER_GAMMA)


In [None]:
##training loop
# Each "epoch" below is one optimisation step on up to BATCH_SIZE point clouds drawn
# by sample_data; the network itself is fed one cloud at a time (batch axis of 1).
train_losses = []
train_iterations = []
val_losses = []
val_iterations = []


def _sample_loss(data, labels):
    """Forward one point cloud through ``net``; return ``(loss, preds, target)``.

    ``loss`` is the per-point cross-entropy plus REG_WEIGHT times the mean
    ||M2 M2^T - I|| penalty on the matrix M2 returned by the network.
    ``preds`` is ``preds[0]`` of the network output, and ``target`` is the label
    tensor without its leading batch axis.
    """
    data = torch.from_numpy(np.expand_dims(np.array(data), 0)).float().to(device)
    labels = torch.from_numpy(np.expand_dims(np.array(labels), 0)).to(device)
    preds, M2 = net(data)
    loss = criterion(preds[0, :, :], labels[0, :])
    # transformation matrix regularization: penalise M2 M2^T deviating from identity
    eye = torch.eye(M2.size(1), device=M2.device).unsqueeze(0)
    reg = torch.mean(torch.norm(torch.bmm(M2, M2.transpose(2, 1)) - eye, dim=(1, 2)))
    return loss + REG_WEIGHT * reg, preds[0, :, :], labels[0, :]


for epoch in range(EPOCHS):
    # NOTE(review): eval mode is used on purpose ("to allow for batch of 1"). In eval
    # mode any BatchNorm layers normalise with, and never update, their running
    # statistics, and dropout is disabled. Confirm this is intended for training.
    net.eval()
    optimizer.zero_grad()
    all_data, all_labels = sample_data(ShapeNetData, 'train', BATCH_SIZE)
    n_train = max(len(all_data), 1)  # guard the averages below against an empty draw
    epoch_loss = 0.
    mIoU_batch = 0.
    for data, labels in zip(all_data, all_labels):
        loss, preds, target = _sample_loss(data, labels)
        # Accumulate gradients of the batch-mean loss sample by sample, so every sample
        # in the batch contributes to the step. Each graph is freed right after backward().
        (loss / n_train).backward()
        epoch_loss += loss.item()
        # argmax over the class axis gives the predicted part label per point
        mIoU_batch += compute_mIoU(preds.argmax(dim=1), target)

    train_losses.append(epoch_loss / n_train)
    train_iterations.append(epoch)

    optimizer.step()
    # NOTE(review): stepped once per sampled batch, so with LR_STEP=20 and gamma=0.5 the
    # LR halves every 20 steps (about 1e-6 after 200 steps). Confirm this decay is intended.
    scheduler.step()

    print("epoch:", epoch, "mIoU:", mIoU_batch / n_train)
    if epoch % VAL_EVERY == 0:
        with torch.no_grad():
            net.eval()
            # load the batch of eval data (batch size couldn't be too big)
            all_data, all_labels = sample_data(ShapeNetData, 'val', BATCH_SIZE)
            val_total = 0.
            for data, labels in zip(all_data, all_labels):
                val_loss, _, _ = _sample_loss(data, labels)
                val_total += val_loss.item()
            # record the mean over the whole validation batch
            val_losses.append(val_total / max(len(all_data), 1))
            val_iterations.append(epoch)

    # Redraw the curves on a fresh figure and close it, so lines and legends do not
    # pile up on one implicit figure across epochs.
    fig, ax = plt.subplots()
    ax.plot(train_iterations, train_losses, 'b', val_iterations, val_losses, 'r')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend(['Train', 'Val'])
    ax.set_title('Epoch vs Loss')
    fig.savefig("./part_seg_losses.png")  # save graph for training visualization
    plt.close(fig)

In [None]:
# ###save stuff
import os

folder = 'results/'
os.makedirs(folder, exist_ok=True)  # torch.save / open fail if the directory is missing
prefix = folder + str(class_ID)
torch.save(net.state_dict(), prefix + '_part_seg_model')

# Context managers guarantee each pickle is flushed and closed. Previously the handles
# were never closed explicitly and relied on garbage collection to do it.
for suffix, obj in (('_part_seg_train_iterations', train_iterations),
                    ('_part_seg_train_losses', train_losses),
                    ('_part_seg_val_iterations', val_iterations),
                    ('_part_seg_val_losses', val_losses)):
    with open(prefix + suffix, 'wb') as filehandler:
        pickle.dump(obj, filehandler)

In [None]:
##plot training loss
import matplotlib
import scipy.interpolate


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards.

    NOTE: pickle.load can execute arbitrary code; only use it on files this notebook wrote.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _spline(x, y, n_points=200):
    """Resample ``y(x)`` on ``n_points`` evenly spaced x values using a cubic spline.

    ``y`` may hold floats or 0-d tensors; both are converted with ``float``.
    Returns ``(x_new, y_new)``.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray([float(v) for v in y])
    x_new = np.linspace(x.min(), x.max(), n_points)
    return x_new, scipy.interpolate.CubicSpline(x, y)(x_new)


prefix = folder + str(class_ID)
train_iterations = _load_pickle(prefix + '_part_seg_train_iterations')
train_losses = _load_pickle(prefix + '_part_seg_train_losses')
test_iterations = _load_pickle(prefix + '_part_seg_val_iterations')
test_losses = _load_pickle(prefix + '_part_seg_val_losses')

matplotlib.rcParams['axes.linewidth'] = 5
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
plt.rc('font', size=35)

fig, ax = plt.subplots(figsize=(12 * 1.5, 8 * 1.5))

# Validation is recorded only every VAL_EVERY epochs, so its curve must be splined
# against its own iterations. Reusing the training iterations fails once VAL_EVERY > 1.
val_x, val_y = _spline(test_iterations, test_losses)
ln = ax.plot(val_x, val_y, label="Validation Loss", linewidth=5, color='#ff7f0e')
ax.tick_params(axis='y', colors='#ff7f0e')

ax2 = ax.twinx()  # second y-axis so each curve gets its own scale
train_x, train_y = _spline(train_iterations, train_losses)
ln += ax2.plot(train_x, train_y, label="Train Loss", linewidth=5, color='#1f77b4')
ax2.tick_params(axis='y', colors='#1f77b4')
ax2.spines['right'].set_color('#1f77b4')
ax2.spines['left'].set_color('#ff7f0e')

labs = [line.get_label() for line in ln]
ax.legend(ln, labs, loc=0)

ax.set_xlabel("#Iteration")
fig.savefig("Iteration.pdf", bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
## compute mean IoU (mIoU) on the test set

def get_all_data(dataset, data_type):
    """Return ``(point_clouds, labels)`` for every sample of ``data_type`` in ``dataset``.

    Wraps ``dataset.get_all_data(data_type)`` and discards its third return value
    (``fns``, presumably file names).
    """
    pcds, labels, _fns = dataset.get_all_data(data_type)
    return pcds, labels


all_data, all_labels = get_all_data(ShapeNetData, 'test')
net.eval()
mIoU = 0.
with torch.no_grad():  # inference only: no autograd graph is kept per sample
    for data, labels in zip(all_data, all_labels):
        data = torch.from_numpy(np.expand_dims(np.array(data), 0)).float().to(device)
        labels = torch.from_numpy(np.expand_dims(np.array(labels), 0)).to(device)
        preds, _ = net(data)
        # predicted part per point = argmax over the class axis (not the max logit value)
        mIoU += compute_mIoU(preds[0, :, :].argmax(dim=1), labels[0, :])

print('test mIoU:', mIoU / max(len(all_data), 1))