In [None]:
import os
import time

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import model
import utils

In [None]:
batch_size = 256
log_folder = "logs/" # folder path to save the results
save_results = True # save the results to log_folder
latent_size = 128 # bottleneck size of the Autoencoder model
seed = 0 # random seed for weight initialisation, shuffling and sampling

category = "Chair"
n_points = 2048

# Seed NumPy's and PyTorch's global RNGs so that repeated Restart & Run All
# executions start from the same weights and batch order.
np.random.seed(seed)
torch.manual_seed(seed)

if save_results:
    utils.clear_folder(log_folder)
    writer = SummaryWriter(log_folder + "TB")

In [None]:
from data.load_dataset import get_dataset
from torch.utils.data import TensorDataset, DataLoader

train_set = get_dataset(category, "train", n_points)
val_set = get_dataset(category, "validation", n_points)

# Column 3 of every point holds its part label (columns 0-2 are xyz, matching how
# batches are split in the training loop). Take the maximum over the label column
# only -- a whole-array max would also scan the coordinates -- and over both
# splits, so every validation label fits in the part_count + 1 output classes.
part_count = int(max(train_set[:, :, 3].max(), val_set[:, :, 3].max()))

print(f"Train set shape :{train_set.shape}")
print(f"Validation set shape :{val_set.shape}")
print(f"Number of points : {n_points}")
print(f"Part count : {part_count}")

train_tensor = torch.from_numpy(train_set).float()
val_tensor = torch.from_numpy(val_set).float()

train_loader = DataLoader(dataset=train_tensor, batch_size=batch_size, shuffle=True, pin_memory=True)
val_loader = DataLoader(dataset=val_tensor, batch_size=batch_size, shuffle=True, pin_memory=True)

In [None]:
# Import the module under a distinct alias: the global name `model` is rebound to
# the network instance below, so `model.LPMNet` would fail if this cell were re-run.
import model as lpmnet_module

model = lpmnet_module.LPMNet(n_points, latent_size, part_count)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
model

In [None]:
from pytorch3d.loss import chamfer_distance # chamfer distance for calculating point cloud distance

def rec_criterion(pc1, pc2):
    """Reconstruction loss: Chamfer distance between two point-cloud batches.

    ``chamfer_distance`` returns a pair; only its first element (the distance
    term) is used here, the second is discarded.
    """
    chamfer_pair = chamfer_distance(pc1, pc2)
    return chamfer_pair[0]

# Per-point part classification loss, and Adam over every network parameter.
seg_criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=5e-4)

In [None]:
def train_epoch(epoch_n):
    """Run one optimisation pass over ``train_loader``.

    Each batch is split into coordinates (columns 0-2) and per-point part labels
    (column 3). The model is updated on the sum of the reconstruction loss
    (``rec_criterion``) and the per-point segmentation loss (``seg_criterion``).

    Args:
        epoch_n: Epoch index. Unused; kept so callers can pass the epoch number.

    Returns:
        Tuple ``(rec_loss, seg_loss, accuracy)``, each averaged over batches.

    Raises:
        ValueError: if ``train_loader`` yields no batches.

    Note:
        Leaves the model in eval mode on return.
    """
    model.train()
    t_rec_loss, t_seg_loss, t_accuracy = 0.0, 0.0, 0.0
    n_batches = 0

    for data in train_loader:
        optimizer.zero_grad()

        labels = data[:, :, 3].to(device).long()
        points = data[:, :, 0:3].to(device)

        seg_results, output = model(points)
        rec_loss = rec_criterion(points, output)
        seg_loss = seg_criterion(seg_results.reshape(-1, part_count + 1), labels.reshape(-1))

        # argmax over the class axis keeps the (batch, points) shape of `labels`;
        # squeeze() would additionally drop a batch axis of size 1.
        seg_labels = seg_results.argmax(dim=2)
        correct = seg_labels.eq(labels).sum().item()
        accuracy = correct / float(data.shape[0] * data.shape[1])

        loss = rec_loss + seg_loss
        loss.backward()
        optimizer.step()

        t_rec_loss += rec_loss.item()
        t_seg_loss += seg_loss.item()
        t_accuracy += accuracy
        n_batches += 1

    model.eval()
    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return t_rec_loss / n_batches, t_seg_loss / n_batches, t_accuracy / n_batches

In [None]:
def test_epoch(epoch_n):
    
    t_rec_loss, t_seg_loss , t_accuracy = 0,0,0
    
    with torch.no_grad():
    
        for i, data in enumerate(val_loader):

            labels = data[:,:,3].to(device).long()

            points = data[:,:,0:3].to(device)

            seg_results, output = model(points)

            rec_loss = rec_criterion(points, output)

            seg_loss = seg_criterion( seg_results.view(-1,part_count+1) ,labels.view(-1))

            seg_labels = seg_results.argmax(dim=2,keepdim=True).squeeze()
            correct = seg_labels.eq(labels.data).cpu().sum()
            accuracy = correct.item()/float(data.shape[0]*data.shape[1])

            t_rec_loss += rec_loss.item()
            t_seg_loss += seg_loss.item()
            t_accuracy += accuracy
        
    return t_rec_loss/(i+1) , t_seg_loss/(i+1), t_accuracy/(i+1)

In [None]:
def test_batch(data): # test with a batch of inputs
    with torch.no_grad():
        
        labels = data[:,:,3].to(device).long()
        points = data[:,:,0:3].to(device)
        
        seg_results, output = model(points)
        rec_loss = rec_criterion(points, output)
        
        seg_loss = seg_criterion( seg_results.view(-1,part_count+1) ,labels.view(-1))

        seg_labels = seg_results.argmax(dim=2,keepdim=True).squeeze()
        correct = seg_labels.eq(labels.data).cpu().sum()
        accuracy = correct.item()/float(data.shape[0]*data.shape[1])
        
        loss = seg_loss + rec_loss
        
    return accuracy, rec_loss.item(), seg_loss.item(), output.cpu()

In [None]:
def segmentall(pc):
    """Label every point of ``pc`` with the part predicted by ``model``.

    Args:
        pc: Tensor of point clouds; presumably ``(batch, points, 3)`` coordinates,
            e.g. a reconstruction returned by ``test_batch`` -- TODO confirm.

    Returns:
        NumPy array with ``pc``'s columns followed by one extra column (index 3
        for xyz input) holding the predicted part label (argmax over classes).
    """
    with torch.no_grad():
        # The network is called with xyz-only input everywhere else in this
        # notebook, so feed it the coordinate columns, not a zero-padded tensor.
        seg_results, _ = model(pc[:, :, 0:3].to(device))
        seg_labels = seg_results.argmax(dim=2)

        # Append a label column sized from pc itself rather than the global
        # n_points, so clouds with any number of points work.
        t_data = torch.cat([pc, pc.new_zeros((pc.shape[0], pc.shape[1], 1))], 2)
        t_data[:, :, 3] = seg_labels.to(t_data.device)

    return t_data.cpu().detach().numpy()

In [None]:
n_epochs = 101  # epochs 0..100 inclusive
vis_every = 50  # save sample visualisations every `vis_every` epochs

for i in range(n_epochs):

    startTime = time.time()

    train_rec_loss, train_seg_loss, train_acc = train_epoch(i)
    test_rec_loss, test_seg_loss, test_acc = test_epoch(i) # evaluate on the validation set

    epoch_time = time.time() - startTime

    writeString = "epoch " + str(i) + " epoch time : " + str(epoch_time) + "\n" + \
                  "Train --> Rec : " + str(train_rec_loss) + " Seg : " + str(train_seg_loss) + " Acc : " + str(train_acc) + "\n" + \
                  "Validation --> Rec : " + str(test_rec_loss) + " Seg : " + str(test_seg_loss) + " Acc : " + str(test_acc) + "\n"

    if save_results: # save all outputs to the save folder

        writer.add_scalars('Loss/Reconstruction', {'Train':train_rec_loss, 'Test':test_rec_loss}, i)
        writer.add_scalars('Loss/Segmentation', {'Train':train_seg_loss, 'Test':test_seg_loss}, i)
        writer.add_scalars('Loss/Accuracy', {'Train':train_acc, 'Test':test_acc}, i)

        with open(log_folder + "prints.txt", "a") as file:
            file.write(writeString)

        if i % vis_every == 0:
            test_samples = next(iter(val_loader))
            _, rec_loss, seg_loss, test_output = test_batch(test_samples)
            writer.flush()

            utils.plotPC([test_samples[0:10].numpy(), segmentall(test_output[0:10])], show=False, save=True, name=(log_folder + "epoch_" + str(i)))

    else: # display all outputs

        test_samples = next(iter(val_loader))
        # test_batch returns (accuracy, rec_loss, seg_loss, output); only the
        # reconstruction is needed for plotting.
        _, _, _, test_output = test_batch(test_samples)
        utils.plotPC([test_samples, test_output])

        print(writeString)

        plt.show()

        


In [None]:
# utils.clear_folder is only called when save_results is True, so the folder
# may not exist otherwise.
os.makedirs(log_folder, exist_ok=True)

# Module.cpu() moves the parameters in place: move once, save both artefacts,
# then put the network back on `device` so later cells can keep using it.
model = model.cpu()
torch.save(model.state_dict(), log_folder + "model_state_dict")
torch.save(model, log_folder + "model_save")
model = model.to(device)