!wget https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip

In [1]:
%%capture

import numpy as np
import time
import utils
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import models
import torch.optim as optim
from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

In [2]:
# ---- Run configuration -------------------------------------------------
# Output location and switches
log_folder = "logs/"   # folder path (with trailing slash) where results are written
save_results = True    # when True, the log folder is cleared and a log file is opened below
use_TB = False         # when True, a TensorBoard SummaryWriter is created under log_folder + "TB"

# Execution / model settings
use_GPU = True         # use GPU, False to use CPU
batch_size = 64        # batch size handed to the DataLoaders
latent_size = 128      # bottleneck size of the Autoencoder model

# ---- Output sinks ------------------------------------------------------
if save_results:
    # NOTE(review): presumably empties (or creates) the folder - confirm in utils.clear_folder
    utils.clear_folder(log_folder)
    # Append-mode text log; it stays open for the lifetime of the notebook kernel.
    log_file = open(log_folder + "logs.txt", mode="a")

if use_TB:
    writer = SummaryWriter(log_folder + "TB")

In [3]:
from Dataset import get_dataset
from torch.utils.data import TensorDataset, DataLoader

category = "Chair"        # category name passed to get_dataset
point_size = 2048         # number of points per cloud (second axis of the point-cloud arrays)
image_resolution = 224    # side length of the square images

train_pc, train_im = get_dataset(category, "train", point_size, image_resolution)
val_pc, val_im = get_dataset(category, "validation", point_size, image_resolution)
# Alternative: load previously exported arrays instead of rebuilding the dataset.
#train_pc = np.load("data/train_pc.npy")
#train_im = np.load("data/train_im.npy")
#val_pc = np.load("data/val_pc.npy")
#val_im = np.load("data/val_im.npy")

# Keep only the first three values of every point (presumably the xyz
# coordinates - any extra per-point channels are dropped).
train_pc = train_pc[:, :, 0:3]
val_pc = val_pc[:, :, 0:3]

# Scale images by 1/255 (assumes pixel values in 0..255 - TODO confirm).
# Cast to float32 first: an in-place true divide raises on integer arrays,
# and the tensors below are float32 anyway, so this also avoids keeping a
# float64 copy of the images alive.
train_im = train_im.astype(np.float32, copy=False)
val_im = val_im.astype(np.float32, copy=False)
train_im /= 255.0
val_im /= 255.0

for array in (train_pc, train_im, val_pc, val_im):
    print(array.shape)

train_pc_tensor = torch.from_numpy(train_pc).float()
train_im_tensor = torch.from_numpy(train_im).float()
val_pc_tensor = torch.from_numpy(val_pc).float()
val_im_tensor = torch.from_numpy(val_im).float()

# Each sample is an (image, point cloud) pair, in that order.
train_set = TensorDataset(train_im_tensor, train_pc_tensor)
val_set = TensorDataset(val_im_tensor, val_pc_tensor)

train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
# Validation is not shuffled: the loss is averaged per batch and the last
# batch is smaller, so a fixed order keeps the metric comparable between
# epochs and makes the periodically saved preview batch always the same.
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, shuffle=False, pin_memory=True)

(2658, 2048, 3)
(2658, 224, 224, 3)
(396, 2048, 3)
(396, 224, 224, 3)


In [4]:
# Build the image-to-point-cloud network from the configured sizes.
model = models.IMtoPC(latent_size, point_size)

# Honour the use_GPU switch and fall back to the CPU when CUDA is not
# available, instead of unconditionally requesting "cuda:0".
device = torch.device("cuda:0" if use_GPU and torch.cuda.is_available() else "cpu")
model = model.to(device)
model

IMtoPC(
  (vgg): Vgg16(
    (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv5_1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv5_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), pa

In [5]:
from pytorch3d.loss import chamfer_distance # chamfer distance for calculating point cloud distance

def rec_criterion(pc1, pc2):
    """Reconstruction loss between two batches of point clouds.

    Delegates to ``chamfer_distance`` and returns only the first element
    of the pair it produces (the loss); the second element is discarded.
    """
    return chamfer_distance(pc1, pc2)[0]

# Adam over every parameter of the model, fixed learning rate of 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [6]:
def train_epoch():
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``device`` and
    ``rec_criterion``. Every batch is an (images, point clouds) pair.

    Returns:
        The mean of the per-batch reconstruction losses.
    """
    model.train()
    batch_losses = []

    for batch in train_loader:
        # Clear gradients left over from the previous step.
        model.zero_grad()

        images = batch[0].to(device)
        clouds = batch[1].to(device)

        prediction = model(images)
        loss = rec_criterion(clouds, prediction)

        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return np.mean(batch_losses)

In [7]:
def test_epoch():
    
    model.eval()
    t_rec_loss = []
    
    with torch.no_grad():
    
        for i, data in enumerate(val_loader):
              
            im_batch = data[0].to(device)
            pc_batch = data[1].to(device)

            output = model(im_batch)
            
            rec_loss = rec_criterion(pc_batch, output)
            t_rec_loss.append(rec_loss.item())

    return np.mean(t_rec_loss)

def test_epoch(epoch_n):
    
    t_rec_loss, t_seg_loss , t_accuracy = 0,0,0
    
    with torch.no_grad():
    
        for i, data in enumerate(val_loader):

            labels = data[:,:,3].to(device).long()

            points = data[:,:,0:3].to(device)

            seg_results, output = model(points)

            rec_loss = rec_criterion(points, output)

            seg_loss = seg_criterion( seg_results.view(-1,part_count+1) ,labels.view(-1))

            seg_labels = seg_results.argmax(dim=2,keepdim=True).squeeze()
            correct = seg_labels.eq(labels.data).cpu().sum()
            accuracy = correct.item()/float(data.shape[0]*data.shape[1])

            t_rec_loss += rec_loss.item()
            t_seg_loss += seg_loss.item()
            t_accuracy += accuracy
        
    return t_rec_loss/(i+1) , t_seg_loss/(i+1), t_accuracy/(i+1)

In [8]:
def test_batch(data): # test with a batch of inputs
    model.eval()
    
    with torch.no_grad():

        im_batch = data.to(device)
        
        output = model(im_batch)

    return output.cpu().numpy()

In [None]:
train_list = []  # mean training loss of every finished epoch
test_list = []   # mean validation loss of every finished epoch

num_epochs = 1001   # epochs 0..1000, so the last epoch also hits the preview interval
sample_every = 10   # save input images / predicted point clouds every this many epochs

pbar = tqdm(range(num_epochs))
for epoch in pbar:

    start_time = time.time()

    train_rec_loss = train_epoch()
    test_rec_loss = test_epoch()  # evaluate on the validation loader

    train_list.append(train_rec_loss)
    test_list.append(test_rec_loss)

    # Plot the loss graph up to this epoch.
    utils.plot_graph([train_list, test_list], log_folder + "loss_graph")

    # Measured once, after plotting, so the reported time covers the whole epoch.
    epoch_time = time.time() - start_time

    # Generate the log string.
    writeString = "epoch %d --> train:%0.6f test:%0.6f time:%0.3f" % (epoch, train_rec_loss, test_rec_loss, epoch_time)

    pbar.set_description(writeString)

    # log_file only exists when save_results is enabled in the config cell.
    if save_results:
        log_file.write(writeString + "\n")
        log_file.flush()

    # writer only exists when use_TB is enabled in the config cell.
    if use_TB:
        writer.add_scalar("loss/train", train_rec_loss, epoch)
        writer.add_scalar("loss/test", test_rec_loss, epoch)

    if epoch % sample_every == 0:
        # Take one validation batch; element 0 holds the input images.
        data = next(iter(val_loader))
        rec = test_batch(data[0])

        utils.showIMs(data[0], show=False, save=True, name=log_folder + "ims" + str(epoch))
        utils.plotPC(rec, show=False, save=True, name=log_folder + "pcs" + str(epoch))


epoch 1 --> train:0.003058 test:0.002993 time:26.388:   0%|          | 2/1001 [00:53<7:25:11, 26.74s/it]

In [None]:
# Module.cpu() moves the model in place, so do it once, save both
# artefacts from the CPU copy, then put the model back on the working
# device so later cells can keep using it with `device` inputs.
model.cpu()
torch.save(model.state_dict(), log_folder + "model_state_dict")  # weights only
# Whole-module pickle: loading it requires the same `models` code to be importable.
torch.save(model, log_folder + "model_save")
model.to(device)