In [1]:
## Adapted from ExploringDBSCAN

In [2]:
## Imports
import torch
import numpy as np
import sklearn
import os, errno
import sys
from datetime import datetime
import time
import csv

import pandas as pd
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler

from Instseg_model import MultiLayerFastLocalGraphModelV2 as model1
from dataset import pcloader
from graph_generation import gen_multi_level_local_graph_v3

from math import floor, ceil
from scipy.stats import mode

import open3d as o3d
from plot_utils import add_staple_patch
from utils import parse_labelfile


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
## CLASSIFICATION INFERENCE

## High Level Config Settings

# Multi-level graph-generation settings.
# NOTE(review): within this notebook only 'level_configs' is consumed (by classify_scan); 'base_voxel_size'
# here (0.8) differs from the voxel size classify_scan actually passes (0.6) -- confirm which is intended.
graph_gen_kwargs = dict(
    add_rnd3d=True,
    base_voxel_size=0.8,
    downsample_method='random',
    level_configs=[
        # level 0: small neighbourhood
        dict(graph_gen_kwargs=dict(num_neighbors=64, radius=0.4),
             graph_gen_method='disjointed_rnn_local_graph_v3',
             graph_level=0,
             graph_scale=1),
        # level 1: wider neighbourhood
        dict(graph_gen_kwargs=dict(num_neighbors=192, radius=1.2),
             graph_gen_method='disjointed_rnn_local_graph_v3',
             graph_level=1,
             graph_scale=1),
    ],
)

def configure_model(model_params_path, max_cls_classes=3, max_inst_classes=7, verbose=False, map_location="cpu"):
    """
    Build the instance-segmentation GNN and load trained weights into it, ready for inference.

    Args:
        model_params_path: path to a state_dict file readable by ``torch.load``.
        max_cls_classes: passed to the model as ``num_classes``.
        max_inst_classes: passed to the model as ``max_instance_no``.
        verbose: print setup time, and a message before raising on a missing file.
        map_location: device the checkpoint tensors are loaded onto. Defaults to CPU so that
            checkpoints saved from a GPU still load on CPU-only machines.

    Returns:
        The model with the weights loaded, switched to eval mode.

    Raises:
        FileNotFoundError: if ``model_params_path`` is not an existing file.
    """
    start = time.perf_counter()

    # Validate the path before constructing the model so a bad path fails fast.
    if not os.path.isfile(model_params_path):
        if verbose:
            print(f"[ModelParamPathError] {model_params_path} does not exist")
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), model_params_path)

    model = model1(num_classes=max_cls_classes, max_instance_no=max_inst_classes)
    model.load_state_dict(torch.load(model_params_path, map_location=map_location))
    # This helper is only used for inference: disable training-time behaviour
    # (dropout / batch-norm statistics updates) if the model has such layers.
    model.eval()

    if verbose:
        print(f"Model Setup Time (secs) : {time.perf_counter() - start}")
    return model

def classify_scan(model, scan_path, verbose=False, base_voxel_size=0.6):
    """
    Run one scan through the GNN and attach per-keypoint class / instance predictions.

    Args:
        model: loaded segmentation model; called as ``model(vertices, keypoint_indices, edges)``
            and expected to return ``(cls_seg, inst_seg)`` score tensors whose argmax over dim 1
            is the predicted class / instance.
        scan_path: scan file, read with ``pcloader``.
        verbose: print the inference time.
        base_voxel_size: voxel size passed to ``gen_multi_level_local_graph_v3``. Defaults to the
            previously hard-coded 0.6 (NOTE: differs from ``graph_gen_kwargs['base_voxel_size']``).

    Returns:
        - ``np.hstack`` of the last graph level's vertex coordinates with the class and instance
          predictions (x|y|z|cls|inst, i.e. Kx5, when the vertices are 3-D), and
        - the same data as a dict: ``{'vertices': ..., 'cls_preds': (K, 1), 'inst_preds': (K, 1)}``.
    """
    start = time.perf_counter()
    pointxyz, _offset = pcloader(scan_path)
    vertex_coord_list, keypoint_indices_list, edges_list = gen_multi_level_local_graph_v3(
        pointxyz, base_voxel_size, graph_gen_kwargs['level_configs'])
    # Un-cast coordinates of the last (coarsest) level: these are the points the predictions refer to.
    last_layer_v = vertex_coord_list[-1]

    # numpy -> torch with the dtypes the model consumes (float32 coordinates, int64 indices).
    vertex_coord_list = [torch.from_numpy(p.astype(np.float32)) for p in vertex_coord_list]
    keypoint_indices_list = [torch.from_numpy(k.astype(np.int32)).long() for k in keypoint_indices_list]
    edges_list = [torch.from_numpy(e.astype(np.int32)).long() for e in edges_list]

    # Inference only: skip building the autograd graph (saves memory and time per scan).
    with torch.no_grad():
        cls_seg, inst_seg = model(vertex_coord_list, keypoint_indices_list, edges_list)

    # Most probable class / instance per keypoint, as (K, 1) numpy columns.
    cls_preds = torch.argmax(cls_seg, dim=1).cpu().numpy()[:, np.newaxis]
    inst_preds = torch.argmax(inst_seg, dim=1).cpu().numpy()[:, np.newaxis]

    if verbose:
        print("Scan Inference Time (secs): ", time.perf_counter() - start)
        print()

    return np.hstack((last_layer_v, cls_preds, inst_preds)), \
        {'vertices': last_layer_v, 'cls_preds': cls_preds, 'inst_preds': inst_preds}

def filter_out_background(scan_data):
    """
    Drop background rows from an x|y|z|cls|inst prediction array.

    A row is kept only when both its class prediction (column 3) and its instance
    prediction (column 4) are non-zero.

    Returns:
        (kept_rows, keep_mask) where keep_mask is the boolean row mask that was applied.
    """
    cls_is_foreground = scan_data[:, 3] != 0
    inst_is_foreground = scan_data[:, 4] != 0
    # Equivalent to: not (cls == 0 or inst == 0)
    keep_mask = cls_is_foreground & inst_is_foreground
    return scan_data[keep_mask], keep_mask

def count_cluster_by_instance_prediction(scan_data, threshold_factor=0.5, inst_col=4):
    """
    Count the predicted instance ids that have enough points to be treated as real clusters.

    An instance id counts when its number of points is strictly greater than
    ``threshold_factor * mean(points per instance id)``.

    Args:
        scan_data: 2-D array of per-point rows (x|y|z|cls|inst by default).
        threshold_factor: multiplier applied to the mean instance size.
        inst_col: column holding the instance prediction (default 4, as in save_prediction_plots).

    Returns:
        cluster_count (int) and a dict of intermediate/final values:
          'orig_insts'     -- unique instance ids,
          'orig_inst_ct'   -- number of points per id,
          'inst_ct_thresh' -- the size threshold (NaN when scan_data has no rows),
          'reduced_insts'  -- ids whose count exceeds the threshold.
    """
    inst, inst_count = np.unique(scan_data[:, inst_col], return_counts=True)
    # np.mean of an empty array emits a RuntimeWarning and returns NaN; make that case explicit.
    inst_count_threshold = threshold_factor * np.mean(inst_count) if inst_count.size else np.nan
    above_threshold = inst_count > inst_count_threshold
    cluster_count = int(np.count_nonzero(above_threshold))
    reduced = inst[above_threshold]

    return cluster_count, {'orig_insts': inst,
                           'orig_inst_ct': inst_count,
                           'inst_ct_thresh': inst_count_threshold,
                           'reduced_insts': reduced}
def draw_labels(welds, ax):
    """
    Overlay one staple patch per labelled weld on ``ax``.

    Args:
        welds: iterable of dicts read with keys 'xloc', 'yloc', 'yaw' and 'cls'
               (presumably the "welds" entry of parse_labelfile's output -- confirm).
        ax: matplotlib Axes to draw on.

    Returns:
        list of the values returned by add_staple_patch, one per weld
        (previously assigned to an unused local and discarded).
    """
    return [add_staple_patch(ax, weld['xloc'], weld['yloc'], weld["yaw"], weld['cls'])
            for weld in welds]
        
def _plot(path, points, values, title="", caption="", overlay_points=None, labels_dict=None):
    """
    Scatter-plot points in the x/y plane coloured by ``values``, then save or show the figure.

    Args:
        path: output image path. If falsy, the figure is shown instead of saved.
        points: 2-D array whose first two columns are x and y.
        values: per-point value used for colour (e.g. predicted/true cls or inst).
        title: axes title.
        caption: italic text placed below the axes.
        overlay_points: optional list of dicts drawn on top of the main scatter:
            ``{"xs": [...], "ys": [...], "cs": colour, "marker": m, "ss": size (optional), "label": str (optional)}``.
        labels_dict: optional dict whose "welds" entry is drawn as staple patches via draw_labels.
    """
    fig, ax = plt.subplots()
    ax.set_title(title)
    im = ax.scatter(points[:, 0], points[:, 1], s=0.25, c=values)

    for pts in overlay_points or []:
        ax.scatter(pts["xs"], pts["ys"], s=pts.get("ss"), c=pts["cs"],
                   marker=pts["marker"], label=pts.get("label"))

    if labels_dict:
        draw_labels(labels_dict["welds"], ax)

    ax.set_xlabel("x [mm]")
    ax.set_ylabel("y [mm]")

    # One legend entry per distinct colour value of the main scatter.
    # ax.legend() already attaches the legend; adding it again via add_artist drew it twice.
    ax.legend(*im.legend_elements(), bbox_to_anchor=(1.1, 1), loc="upper right")

    ax.text(0.5, -0.5, caption, style='italic',
            horizontalalignment='center', verticalalignment='top', transform=ax.transAxes)
    ax.set_aspect(1)

    if path:
        fig.savefig(path, dpi=150)
        # Close exactly this figure. The previous close() followed by cla()/clf()
        # implicitly created (and leaked) a new empty figure on every call.
        plt.close(fig)
    else:
        plt.show()  # previously `plt.show` without parentheses, which never displayed anything
    
def save_prediction_plots(non_bg_matrix, labels_dict=None, cls_tag=None, inst_tag=None, cls_col=3, inst_col=4, inst_seg_dir_path="./plots/inst/", cls_seg_dir_path="./plots/cls/", verbose=False):
    """
    Plot the class and/or instance predictions of (non-background) points, one image per tag.

    Args:
        non_bg_matrix: rows of x|y|z|cls|inst (prediction columns set by cls_col / inst_col).
        labels_dict: optional ground-truth labels, forwarded to _plot for the staple overlay.
        cls_tag, inst_tag: title and file stem (e.g. "A_xycls") of the class / instance plot.
            A plot is made for each tag that is given; at least one must be given.
        cls_col, inst_col: columns holding the class / instance predictions.
        inst_seg_dir_path, cls_seg_dir_path: output directories, created if missing. When a
            directory is None/empty, _plot is given no path for that plot.
        verbose: print the produced paths.

    Returns:
        list of the output paths passed to _plot (None for unsaved plots); empty when nothing was plotted.
    """
    f_tag = "[save_prediction_plots]"
    f_msg = []

    # Previously `not (cls_tag and inst_tag)`: that refused to plot unless BOTH tags were
    # given, contradicting the per-tag handling below and this error message.
    if not (cls_tag or inst_tag):
        print("[ERROR] No specfied cls_tag or inst_tag args. Plots not generated")
        return f_msg

    # (tag, output directory, prediction column), class plot first as before.
    jobs = ((cls_tag, cls_seg_dir_path, cls_col),
            (inst_tag, inst_seg_dir_path, inst_col))
    for tag, dir_path, col in jobs:
        if not tag:
            continue
        output_path = None
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
            # os.path.join also works when dir_path lacks a trailing separator.
            output_path = os.path.join(dir_path, tag + ".png")
        _plot(output_path, non_bg_matrix[:, :2], non_bg_matrix[:, col], title=tag, labels_dict=labels_dict)
        f_msg.append(output_path)

    if verbose:
        print(f"{f_tag}: Done --> {f_msg}")
    return f_msg
            


###################################################
###  EVENTUALLY BUT SKIPPING THIS FOR NOW #########
## Need to record the model weights version
## Need to record the number of instances identified, the instance_seg_loss and cls_loss
def save_prediction_stats(model, scan_path, output_file):
    """
    Placeholder: not implemented yet. It currently does nothing and returns None.

    Planned behaviour: initialize the text file if it does not exist, with the columns
    name | prediction_clusters | prediction cluster counts | cluster_count_threshold | reduced_clusters
    """
    return None

###################################################

def run_preclustering(model, scan_path, labels_dict=None, sample_tag="", inst_seg_img_dir_path="./plots/inst_seg/", cls_seg_img_dir_path="./plots/cls_seg/", verbose=True):
    """
    Run one scan through the model, drop background points, and plot its cls and inst predictions.

    Steps:
        1. classify_scan                       -> x|y|z|cls|inst per keypoint
        2. filter_out_background               -> rows with non-zero cls and inst
        3. count_cluster_by_instance_prediction on the remaining rows
        4. save_prediction_plots, using ``sample_tag`` as title/file stem for both plots.
           An empty ``sample_tag`` (the default) makes it print an error and produce no plots.

    Args:
        model: loaded model (see configure_model).
        scan_path: scan file to classify.
        labels_dict: optional ground-truth labels overlaid on the plots.
        sample_tag: title / file stem for the plots.
        inst_seg_img_dir_path, cls_seg_img_dir_path: forwarded to save_prediction_plots.
        verbose: forwarded to classify_scan (previously always True, still the default).

    Returns:
        (non-background x|y|z|cls|inst array, info dict). The dict keeps the original
        "0_scan" / "0_tag" keys and adds "1_cluster_count" and "1_cluster_data" from step 3,
        which were previously computed and thrown away.
    """
    scan_data, _ = classify_scan(model, scan_path, verbose=verbose)
    non_bg_data, _ = filter_out_background(scan_data)
    cluster_count, cluster_data = count_cluster_by_instance_prediction(non_bg_data)

    save_prediction_plots(non_bg_data, cls_tag=sample_tag, inst_tag=sample_tag,
                          cls_seg_dir_path=cls_seg_img_dir_path, inst_seg_dir_path=inst_seg_img_dir_path,
                          labels_dict=labels_dict)
    return non_bg_data, {"0_scan": scan_path,
                         "0_tag": sample_tag,
                         "1_cluster_count": cluster_count,
                         "1_cluster_data": cluster_data}

In [4]:
# ## Set up model
# myDetector = Detector(model_params="_model/train1-fix3/2023_06_30_09_56_12/params_epoch488_for_min_test_loss.pt")

# ## Looping through the test set
# ## TEMP: single sample
# scan_path = "./_data/scans/LH-3-231201600-Pass-2023_06_12-9-38-37-588.ply"
# result = myDetector.process_sample(scan_path)
# print(result)
# ## For each scan in test set, run the classification prediction
# ## Plot and save images for each 
# ## collate losses: --> scan name: instance seg loss, class seg loss


# ## For each scan in training set 

In [5]:
# ## Test Running a sample

# ## Run Settings
# model_params_path = "./_model/train1-fix3/2023_06_22_17_01_13/params_epoch497_for_min_test_loss.pt"
# scan_path = "./_data/scans/LH-3-231201600-Pass-2023_06_12-9-38-37-588.ply"
# sample_name = "LH-3-231201600-Pass-2023_06_12-9-38-37-588"

# ## Instantiating model and running scan through for predictions
# model = configure_model(model_params_path, verbose=True)
# scan_data, _ = classify_scan(model, scan_path, verbose=True)
# print("scan_data.shape :")
# print(scan_data.shape)
# print()


# ## Filter out background points
# scan_data, _ = filter_out_background(scan_data)
# print("scan_data.shape :")
# print(scan_data.shape)
# print()

# ## Clustering by the instance predictions
# cluster_count, cluster_data = count_cluster_by_instance_prediction(scan_data)
# print("cluster_count: ")
# print(cluster_count)
# print()
# print("cluster_data: ")
# print(cluster_data)
# print()

# # ## Plotting class predictions and instance predictions
# # save_prediction_plots(scan_data, cls_tag='egA_xycls', inst_tag='egA_xyinst')

# inst_seg_img_directory = "C:/Users/KZTYLF/Documents/playground/GNN UIs/GNN InstanceSegmentation/Recreating Dataset/_img_instance_seg/"
# cls_seg_img_directory = "C:/Users/KZTYLF/Documents/playground/GNN UIs/GNN InstanceSegmentation/Recreating Dataset/_img_class_seg/"
# save_prediction_plots(scan_data, cls_tag=sample_name, inst_tag=sample_name, cls_seg_dir_path=cls_seg_img_directory, inst_seg_dir_path=inst_seg_img_directory)

In [6]:
# ## Test Running a sample

# ## Run Settings
# model_params_path = "./_model/train1-fix3/2023_06_22_17_01_13/params_epoch497_for_min_test_loss.pt"
# scan_path = "./_data/scans/LH-3-231201600-Pass-2023_06_12-9-38-37-588.ply"
# sample_name = "LH-3-231201600-Pass-2023_06_12-9-38-37-588"

# ## Instantiating model and running scan through for predictions and plots on predictions
# model = configure_model(model_params_path, verbose=True)

# inst_seg_img_directory = "C:/Users/KZTYLF/Documents/playground/GNN UIs/GNN InstanceSegmentation/Recreating Dataset/_img_instance_seg/"
# cls_seg_img_directory = "C:/Users/KZTYLF/Documents/playground/GNN UIs/GNN InstanceSegmentation/Recreating Dataset/_img_class_seg/"
# run_preclustering(model, scan_path, sample_tag=sample_name, inst_seg_img_dir_path=inst_seg_img_directory, cls_seg_img_dir_path=cls_seg_img_directory)

In [8]:
## Running inference on all the training and test samples and plotting results

## Every dataset / output location is derived from this one root, so relocating the project is a one-line edit.
## TODO: make this relative or read it from an environment variable instead of a user-specific absolute path.
DATASET_ROOT = "C:/Users/KZTYLF/Documents/playground/GNN UIs/GNN InstanceSegmentation/Recreating Dataset"
RUN_TEST_SET = False   # test split with ground-truth label overlay (previously a commented-out block)
RUN_TRAIN_SET = True


def _under_root(*parts):
    """Join DATASET_ROOT with ``parts``; pass a trailing "" to keep a trailing separator on directories."""
    return os.path.join(DATASET_ROOT, *parts)


## Model Instantiating
model_params_path = "./_model/train1-fix3/2023_06_22_17_01_13/params_epoch497_for_min_test_loss.pt"
model = configure_model(model_params_path, verbose=True)

# NOTE(review): `parse_dataset_file` is not imported or defined anywhere in this notebook; the saved
# outputs only exist because it was already present in the kernel. Import it in the imports cell
# (TODO: confirm which module provides it) or this cell raises NameError on Restart & Run All.

## Test samples: generate and save plots with the ground-truth labels overlaid
if RUN_TEST_SET:
    test_txt = _under_root("_data", "test.txt")
    test_samples = parse_dataset_file(test_txt)
    inst_seg_img_directory = _under_root("_img_instance_seg", "test_set_with_label_overlay", "")
    cls_seg_img_directory = _under_root("_img_class_seg", "test_set_with_label_overlay", "")

    for sample in test_samples:
        scan_path = sample["scan_path"]
        sample_name = sample["name"]
        labels_dict = parse_labelfile(sample["label_path"])
        run_preclustering(model, scan_path, labels_dict=labels_dict, sample_tag=sample_name,
                          inst_seg_img_dir_path=inst_seg_img_directory,
                          cls_seg_img_dir_path=cls_seg_img_directory)

## Train samples: generate and save plots (no label overlay)
if RUN_TRAIN_SET:
    train_txt = _under_root("_data", "train.txt")
    train_samples = parse_dataset_file(train_txt)
    inst_seg_img_directory = _under_root("_img_instance_seg", "train_set", "")
    cls_seg_img_directory = _under_root("_img_class_seg", "train_set", "")

    for sample in train_samples:
        scan_path = sample["scan_path"]
        sample_name = sample["name"]
        run_preclustering(model, scan_path, sample_tag=sample_name,
                          inst_seg_img_dir_path=inst_seg_img_directory,
                          cls_seg_img_dir_path=cls_seg_img_directory)


Model Setup Time (secs) : 0.5708363056182861
Scan Inference Time (secs):  10.51757287979126

Scan Inference Time (secs):  10.537972688674927

Scan Inference Time (secs):  10.738929271697998

Scan Inference Time (secs):  10.024059772491455

Scan Inference Time (secs):  10.495957612991333

Scan Inference Time (secs):  9.566447734832764

Scan Inference Time (secs):  9.555433988571167

Scan Inference Time (secs):  8.909443855285645

Scan Inference Time (secs):  9.63809871673584

Scan Inference Time (secs):  9.579296112060547

Scan Inference Time (secs):  10.271457433700562

Scan Inference Time (secs):  10.109545469284058

Scan Inference Time (secs):  9.335759162902832

Scan Inference Time (secs):  10.822887420654297

Scan Inference Time (secs):  9.290156364440918

Scan Inference Time (secs):  9.180329084396362

Scan Inference Time (secs):  9.76867413520813

Scan Inference Time (secs):  9.844476222991943

Scan Inference Time (secs):  9.917199850082397

Scan Inference Time (secs):  9.7549750

Scan Inference Time (secs):  9.69490909576416

Scan Inference Time (secs):  9.257021427154541

Scan Inference Time (secs):  10.008171319961548

Scan Inference Time (secs):  10.716116428375244

Scan Inference Time (secs):  9.921858549118042

Scan Inference Time (secs):  10.966448545455933

Scan Inference Time (secs):  11.05726671218872

Scan Inference Time (secs):  11.185437202453613

Scan Inference Time (secs):  10.644825220108032

Scan Inference Time (secs):  9.851844310760498

Scan Inference Time (secs):  9.41592288017273

Scan Inference Time (secs):  9.634252071380615

Scan Inference Time (secs):  9.897159814834595

Scan Inference Time (secs):  9.717150449752808

Scan Inference Time (secs):  9.371872186660767

Scan Inference Time (secs):  9.813512325286865

Scan Inference Time (secs):  9.380998373031616

Scan Inference Time (secs):  9.491893291473389

Scan Inference Time (secs):  11.543368816375732

Scan Inference Time (secs):  10.116774082183838

Scan Inference Time (secs):  9.5020

Scan Inference Time (secs):  10.447551012039185

Scan Inference Time (secs):  9.663655757904053

Scan Inference Time (secs):  9.694123983383179

Scan Inference Time (secs):  10.20733380317688

Scan Inference Time (secs):  10.22365117073059

Scan Inference Time (secs):  10.060018062591553

Scan Inference Time (secs):  9.614469766616821

Scan Inference Time (secs):  10.41034722328186

Scan Inference Time (secs):  9.412639141082764

Scan Inference Time (secs):  9.73371410369873

Scan Inference Time (secs):  9.594569206237793

Scan Inference Time (secs):  9.70112943649292

Scan Inference Time (secs):  10.229320764541626

Scan Inference Time (secs):  9.212526798248291

Scan Inference Time (secs):  9.98763918876648

Scan Inference Time (secs):  9.619618892669678

Scan Inference Time (secs):  9.73998475074768

Scan Inference Time (secs):  8.77487063407898

Scan Inference Time (secs):  11.040887355804443

Scan Inference Time (secs):  9.928213834762573

Scan Inference Time (secs):  9.6045126914

<Figure size 640x480 with 0 Axes>

In [3]:
## Scratch check: with only a condition, np.where behaves like np.nonzero (a tuple of index arrays, one per axis)
import numpy as np
a = np.arange(0, 10)
np.nonzero(a < 5)


(array([0, 1, 2, 3, 4], dtype=int64),)