In [None]:
import math
import os
from tqdm import tqdm

from utils import *
from datasets import *

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import (
    Compose,
    Grayscale,
    InterpolationMode,
    Resize,
    ToTensor,
    CenterCrop
)

from model import PrimitivesNet
from losses import *

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device is : {device}")

In [None]:
# Show the visible NVIDIA GPUs and their current utilisation (IPython `!` shell escape;
# requires the NVIDIA driver's `nvidia-smi` tool on PATH, so it errors on CPU-only machines).
!nvidia-smi

In [None]:
# Seed the random number generators through the project helper so runs are reproducible.
RANDOM_SEED = 24
set_random_seed(RANDOM_SEED)

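### Data Preparation and Exploration

Each subsection below builds the data loaders for one dataset. It overwrites the `train_loader` / `valid_loader` (and `test_loader`, where defined) left by any earlier subsection, so run only the subsection for the dataset you want to use. Note that the Simple dataset subsection does not build a `test_loader`.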

#### TableChair Dataset

In [None]:
# TableChair data lives in a single HDF5 file. Images get their shorter side resized to 64 px
# and are converted to single-channel (grayscale) tensors.
data_path = 'datasets/table_chair.h5'
batch_size = 128
transforms = Compose([
    Resize(64),
    Grayscale(),
    ToTensor(),
])

In [None]:
def dataloader(
    data_path: str,
    training: bool,
    split_type: Literal["train", "valid", "test"],
    batch_size: int = 32,
    transforms: t.Optional[
        t.Callable[[np.ndarray], t.Union[torch.Tensor, np.ndarray]]
    ] = None,
    num_workers: int = 0,
) -> DataLoader:
    """Build a DataLoader over one split of the TableChair HDF5 dataset.

    Args:
        data_path: Path to the TableChair ``.h5`` file.
        training: When True the loader shuffles samples and drops the last
            incomplete batch; when False it iterates in order and keeps it.
        split_type: Split name forwarded to ``TableChairDataset``
            ("train", "valid" or "test").
        batch_size: Number of samples per batch.
        transforms: Optional callable forwarded to ``TableChairDataset``.
        num_workers: Number of loader worker processes; 0 (the default, as
            before) loads data in the main process.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping ``TableChairDataset``.
    """
    return DataLoader(
        dataset=TableChairDataset(data_path, split_type, transforms),
        batch_size=batch_size,
        shuffle=training,
        drop_last=training,
        num_workers=num_workers,
    )

# One loader per split; only the training split is shuffled (and drops its last partial batch).
train_loader, valid_loader, test_loader = (
    dataloader(data_path, split == "train", split, batch_size=batch_size, transforms=transforms)
    for split in ("train", "valid", "test")
)

#### Simple Dataset

In [None]:
# Simple dataset: a flat directory of image files.
# os.listdir returns entries in an arbitrary, filesystem-dependent order. Sort the names so
# the dataset's indexing (and any index-based, seeded train/valid split of it) is reproducible
# across machines.
data_path = 'datasets/simple_dataset'
data_dir = os.path.abspath(data_path)
images_path = [os.path.join(data_dir, name) for name in sorted(os.listdir(data_path))]
batch_size = 128
transforms = Compose([Resize(64), Grayscale(), ToTensor()])

In [None]:
# Wrap the image paths in the project's SimpleDataset, passing along the transforms above.
simple_dataset = SimpleDataset(
    transforms=transforms,
    image_paths=images_path,
)

In [None]:
# Split the dataset into train / validation samplers
# (valid_size=0.2 -- presumably a 20% validation fraction).
train_sampler, valid_sampler = split_dataset(
    simple_dataset,
    valid_size=0.2,
)

In [None]:
# Both loaders read the same dataset; the samplers decide which indices each one visits.
train_loader, valid_loader = (
    torch.utils.data.DataLoader(simple_dataset, batch_size=batch_size, sampler=sampler)
    for sampler in (train_sampler, valid_sampler)
)

#### MNIST Dataset

In [None]:
# MNIST is downloaded under datasets/. The digits are grayscale already, so unlike the other
# datasets there is no Grayscale() step.
data_path = 'datasets/'
batch_size = 128
transforms = Compose([
    Resize(64),
    ToTensor(),
])

In [None]:
from torchvision.datasets import MNIST

# Raw torchvision MNIST splits (train=True, then train=False), each with `transforms` attached.
MNIST_data, MNIST_data_test = (
    MNIST(root=data_path, download=True, train=is_train, transform=transforms)
    for is_train in (True, False)
)

# NOTE(review): `transforms` is handed to MNISTDataset as well -- confirm the wrapper does not
# apply it a second time on top of torchvision's `transform`.
mnist_dataset, mnist_dataset_test = (
    MNISTDataset(dataset=raw_split, transforms=transforms)
    for raw_split in (MNIST_data, MNIST_data_test)
)

In [None]:
# Carve a validation subset out of the MNIST training split (valid_size=0.2); the separate
# MNIST test split gets its own loader with the default (unshuffled) ordering.
train_sampler, valid_sampler = split_dataset(mnist_dataset, valid_size=0.2)

train_loader, valid_loader = (
    torch.utils.data.DataLoader(mnist_dataset, batch_size=batch_size, sampler=sampler)
    for sampler in (train_sampler, valid_sampler)
)
test_loader = torch.utils.data.DataLoader(mnist_dataset_test, batch_size=batch_size)

#### Pet Dataset (Oxford-IIIT Pet)

In [None]:
data_path = 'datasets/pet'
batch_size = 128
transforms = Compose([ToTensor()])

# Input photos: convert to a tensor first, then resize to a fixed 64x64.
tf_rgb = Compose([
    ToTensor(),
    Resize((64, 64)),
])

# Segmentation targets are label maps, not photographs. Resize them with NEAREST so every
# output pixel keeps one of the original label values. The default (bilinear) interpolation
# blends neighbouring labels and can produce values along region boundaries that match
# neither neighbour.
tf_gray = Compose([
    Resize((64, 64), interpolation=InterpolationMode.NEAREST),
    PILToTensor_for_targets(),
])

In [None]:
from torchvision.datasets import OxfordIIITPet

# Oxford-IIIT Pet with segmentation masks as targets: photos go through tf_rgb and masks
# through tf_gray. Built for the "trainval" split first, then "test".
pet_dataset_trainval, pet_dataset_test = (
    OxfordIIITPet(
        root=data_path,
        split=split_name,
        target_types="segmentation",
        transform=tf_rgb,
        target_transform=tf_gray,
        download=True,
    )
    for split_name in ("trainval", "test")
)

In [None]:
# Wrap the raw torchvision datasets in the project's PetDataset.
# Each assignment rebinds the same name it reads, so without a guard re-running this cell
# would wrap an already-wrapped dataset a second time. Only wrap objects that are still raw
# OxfordIIITPet instances, which keeps the cell idempotent.
if isinstance(pet_dataset_trainval, OxfordIIITPet):
    pet_dataset_trainval = PetDataset(dataset=pet_dataset_trainval, transforms=transforms)
if isinstance(pet_dataset_test, OxfordIIITPet):
    pet_dataset_test = PetDataset(dataset=pet_dataset_test, transforms=transforms)

In [None]:
# Carve a validation subset out of "trainval" (valid_size=0.2); the official "test" split
# gets its own loader with the default (unshuffled) ordering.
train_sampler, valid_sampler = split_dataset(pet_dataset_trainval, valid_size=0.2)

train_loader, valid_loader = (
    torch.utils.data.DataLoader(pet_dataset_trainval, batch_size=batch_size, sampler=sampler)
    for sampler in (train_sampler, valid_sampler)
)
test_loader = torch.utils.data.DataLoader(pet_dataset_test, batch_size=batch_size)

In [None]:
# Pull one batch from the training loader and display its first element (the images) as a grid.
inputs = next(iter(train_loader))
batch_images = inputs[0]
show_grid(batch_images)

In [None]:
# ---- model / training hyper-parameters ----
num_channels = 1     # input image channels: 1 for grayscale, 3 for RGB images
# NOTE(review): the Pet section builds RGB inputs (tf_rgb) -- presumably this must be 3 when
# training on it; confirm against what PetDataset returns.
latent_size = 256    # final output size of the encoder
lr = 1e-3            # Adam learning rate
num_epochs = 50
num_shape_type = 8   # number of shapes per primitive type
threshold = 0.5      # weight threshold: weight > threshold -> 1, else 0

# Best (lowest) mean validation Chamfer distance seen so far. Starting at +inf guarantees the
# first epoch saves a checkpoint. A finite sentinel such as 100 would skip saving entirely if
# the CD never dropped below it, and loading the checkpoint afterwards would then fail.
prev_CD = math.inf

model_name = 'primitives_net_ct13'

In [None]:
# Build the primitives network and move its parameters to the selected device
# (Module.to returns the module itself, so `net` is the moved network).
net = PrimitivesNet(num_channels=num_channels, latent_size=latent_size).to(device)
net

In [None]:
# Adam over all parameters of the network, using the learning rate from the config cell.
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)

In [None]:
import tensorboard_logger
from tensorboard_logger import log_value

# Scalars written with log_value() go to logs/tensorboard/<model_name>, flushed every 5 seconds.
tensorboard_logger.configure(f"logs/tensorboard/{model_name}", flush_secs=5)

In [None]:
# Train for `num_epochs` epochs. After every epoch, evaluate on `valid_loader` and save the
# weights whenever the mean validation Chamfer distance beats the best one so far (`prev_CD`).
os.makedirs("trained_models", exist_ok=True)  # torch.save does not create missing directories

for epoch in range(num_epochs):
    epoch_str = '[Epoch {}/{}]'.format(
        str(epoch).zfill(len(str(num_epochs))), num_epochs)

    # ---------------- training ----------------
    net.train()
    # len(loader.sampler) counts the samples this loader actually visits. When a loader is
    # driven by a subset sampler (the train/valid splits above), len(loader.dataset) is the
    # size of the whole dataset and would overstate the progress-bar total.
    pbar = tqdm(total=len(train_loader.sampler), leave=False)

    train_loss = 0.0
    n = 0.0

    for batch in train_loader:
        optimizer.zero_grad()

        image, pt, dist = batch
        image = image.to(device)
        pt = pt.to(device)
        dist = dist.to(device)

        pred = net(image, pt)
        loss = total_loss(pred, dist, net.scaler, net.shape_evaluator, net.boolean_layer)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        n += 1

        pbar.set_description('{} {} Loss: {:f}'.format(epoch_str, 'Train', batch_loss))
        pbar.update(image.shape[0])

    pbar.close()
    mean_train_loss = train_loss / n

    print("Epoch {}/{} => train_loss: {}".format(epoch, num_epochs, mean_train_loss))
    log_value('train_loss', mean_train_loss, epoch)

    # ---------------- validation ----------------
    net.eval()
    pbar = tqdm(total=len(valid_loader.sampler), leave=False)

    valid_loss = 0.0
    valid_CD = 0.0
    valid_IoU = 0.0
    n = 0.0

    with torch.no_grad():
        for batch in valid_loader:
            image, pt, dist = batch
            image = image.to(device)
            pt = pt.to(device)
            dist = dist.to(device)
            pred = net(image, pt)

            loss = total_loss(pred, dist, net.scaler, net.shape_evaluator, net.boolean_layer)
            valid_loss += loss.item()
            n += 1

            # Metrics are computed on the binarized prediction, reshaped to 64x64 maps.
            pred = net.binarize(pred).reshape(-1, 64, 64).clone().cpu().numpy()
            dist = dist.reshape(-1, 64, 64).clone().cpu().numpy()

            CD = chamfer_distance(pred, dist)
            IoU = iou(pred, dist)
            valid_CD += CD
            valid_IoU += IoU

            pbar.set_description('{} {} Loss: {:f}, CD: {:f}, IoU: {:f}%'.format(
                epoch_str, 'Valid', loss.item(), CD, IoU * 100))
            pbar.update(image.shape[0])

    pbar.close()

    # Means over batches (every batch weighted equally).
    mean_valid_loss = valid_loss / n
    valid_CD = valid_CD / n
    valid_IoU = valid_IoU / n
    log_value('valid_loss', mean_valid_loss, epoch)
    log_value('chamfer_distance', valid_CD, epoch)
    log_value('IoU', valid_IoU, epoch)

    print("Epoch {}/{} => valid_loss: {:f}, CD: {:f}, IoU: {:f}%".format(
        epoch, num_epochs, mean_valid_loss, valid_CD, valid_IoU * 100))

    # Checkpoint whenever the validation Chamfer distance improves.
    if prev_CD > valid_CD:
        print("Saving the Model based on Chamfer Distance: %f" % (valid_CD), flush=True)
        torch.save(net.state_dict(), "trained_models/{}.pth".format(model_name))
        prev_CD = valid_CD

In [None]:
# Load the best checkpoint written by the training loop. map_location remaps tensors that were
# saved from a GPU onto `device`, so the checkpoint also loads on a CPU-only machine.
net.load_state_dict(torch.load("trained_models/{}.pth".format(model_name), map_location=device))

### Prediction

In [None]:
# Visualise primitive fits for one random sample from the first validation batch
# (originally written for the CAD dataset). The cell shows:
#   - the input image and the network's reconstruction;
#   - every predicted primitive mask, titled with its coverage score;
#   - minimum-area rectangles / enclosing circles fitted to those masks, drawn over the
#     reconstruction.

boxes = []
circles = []
coverage_cir = []
triangles = []
coverage_tri = []
coverage_rect = []

# Inference only: switch to eval mode and skip autograd bookkeeping. A network that was just
# constructed and had a checkpoint loaded into it is still in train mode at this point.
net.eval()

with torch.no_grad():
    image, pt, dist = next(iter(valid_loader))
    image = image.to(device)
    pt = pt.to(device)

    recon, pred = net(image, pt, return_shapes_distances=True)

# NOTE(review): the loop below iterates pred[rand_num] and reshapes each row to 64x64, so after
# this permute dim 1 must index shapes and dim 2 the 64*64 points. The old
# "[batch, num_pt, num_shape]" note seems to describe the layout *before* the permute -- confirm.
pred = pred.permute((0, 2, 1))

# Pick a random sample from this batch. Use the actual batch length: a batch can be smaller
# than `batch_size`, and random.randint's upper bound is inclusive.
rand_num = random.randint(0, image.shape[0] - 1)

gt_img = image[rand_num].cpu()
img = image[rand_num].cpu().numpy().transpose((1, 2, 0))  # move the leading axis last for imshow
plt.imshow(img)
plt.show()

pred = net.binarize(pred)

recon = recon[rand_num].cpu().numpy().reshape(64, 64)
print("reconstructed:")
plt.imshow(recon, cmap="viridis")
plt.show()

plt.figure(figsize=(8 * 2, 4 * 2))

for i, prd in enumerate(pred[rand_num]):
    p = prd.cpu().numpy().reshape(64, 64)
    prd_m = torch.unsqueeze(prd.reshape(64, 64), 0)

    coverage = coverage_threshold(gt_img, prd_m.cpu())
    coverage = coverage.item() if isinstance(coverage, torch.Tensor) else coverage

    plt.subplot(4, 8, i + 1)
    plt.imshow(p, cmap="viridis")
    plt.axis("off")
    plt.title(round(coverage, 2))

    contours, _hierarchy = cv2.findContours(
        (255 * p).astype('uint8'), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    # An empty binarized mask has no contour; skip fitting it instead of failing on contours[0].
    if not contours:
        continue

    if i < num_shape_type:
        # The first `num_shape_type` primitives are treated as rectangles (previously the
        # hard-coded `i <= 7`, which equals this for num_shape_type == 8).
        rect = cv2.minAreaRect(contours[0])
        box = cv2.boxPoints(rect).astype(np.intp)  # integer corners; np.int0 was removed in NumPy 2.0
        boxes.append(box)
        coverage_rect.append(coverage)
    else:
        # The remaining primitives are treated as circles.
        (x, y), radius = cv2.minEnclosingCircle(contours[0])
        circles.append(((int(x), int(y)), int(radius)))
        coverage_cir.append(coverage)

plt.tight_layout()
plt.show()

# Overlay the fitted primitives on a 3-channel float copy of the reconstruction: rectangles in
# red, circles in green. The colours use a 0-1 scale, presumably matching the reconstruction's
# value range.
blank_img = np.repeat(recon[:, :, np.newaxis], 3, axis=2).astype(np.float64)

for bbox in boxes:
    cv2.drawContours(blank_img, [bbox], 0, (1, 0, 0), 1)

for cent, rad in circles:
    cv2.circle(blank_img, cent, rad, (0, 1, 0), 1)

plt.figure(figsize=(12, 12))
panels = [("Input", img), ("Reconstructed", recon), ("Primitives", blank_img)]
for k, (title, panel) in enumerate(panels, start=1):
    plt.subplot(1, 3, k)
    plt.imshow(panel)
    plt.title(title)
    plt.axis('off')
plt.show()