In [1]:
## Imports
import torch
import numpy as np
import sklearn
import os, errno
import sys
from datetime import datetime
import time
import csv

import pandas as pd
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler

from Instseg_model import MultiLayerFastLocalGraphModelV2 as model1
from dataset import pcloader
from graph_generation import gen_multi_level_local_graph_v3

from math import floor, ceil
from scipy.stats import mode

import open3d as o3d
from plot_utils import add_staple_patch
from utils import parse_labelfile

ModuleNotFoundError: No module named 'Instseg_model'

In [None]:
## UTILS

## Parse txt files that list scan and label paths per row eg test.txt 
def parse_dataset_file(data_set_filepath):
    """
    Parse a dataset index file (e.g. test.txt) that lists one scan/label pair per row.

    Each non-blank row is expected to look like ``['<scan_path>', '<label_path>']``.

    Returns:
     - a list of dicts, each with the "scan_path", the "label_path" and the sample
       "name" (the label file's base name with a trailing ".txt" removed)
    """
    suffix = ".txt"
    samples = []
    with open(data_set_filepath, 'r') as setfile:
        for line in setfile:
            line = line.strip().strip("[]")  # remove the [] bookending each row
            if not line:
                continue  # tolerate blank / trailing lines instead of crashing on unpack
            scan, label = line.split(",")
            scan_path = scan.strip().strip("'")
            label_path = label.strip().strip("'")
            name = label_path.split("/")[-1]
            # str.strip(".txt") would strip any of the characters '.', 't', 'x' from
            # BOTH ends (e.g. "test.txt" -> "es"); remove only the literal suffix.
            if name.endswith(suffix):
                name = name[:-len(suffix)]
            samples.append({"scan_path": scan_path, "label_path": label_path, "name": name})
    return samples

def _plot(points, values, out_path=None, title="", caption="", overlay_points=None, labels_dict=None ):
    """
    Scatter-plot the x/y columns of `points`, coloured by `values`, with a legend.

    points         - array whose first two columns are plotted as x and y
    values         - per-point value used for colouring, e.g. predicted/true cls or inst
    out_path       - if given, the figure is saved there (dpi=150) and closed;
                     otherwise it is displayed with plt.show()
    title, caption - axes title, and an italic caption drawn below the axes
    overlay_points - optional list of dicts with keys "xs", "ys", "ss" (marker sizes),
                     "cs" (colours) and "marker"; each is drawn as an extra scatter layer
    labels_dict    - optional dict whose "welds" entries (dicts with "xloc", "yloc",
                     "yaw", "cls") are drawn as staple patches via add_staple_patch
    """
    fig, ax = plt.subplots()
    ax.set_title(title)
    im = ax.scatter(points[:, 0], points[:, 1], s=0.25, c=values)

    for pts in overlay_points or ():
        ax.scatter(pts["xs"], pts["ys"], s=pts["ss"], c=pts["cs"], marker=pts["marker"])

    if labels_dict:
        # Draw each labelled weld directly (there is no live draw_labels helper in this notebook).
        for weld in labels_dict["welds"]:
            add_staple_patch(ax, weld['xloc'], weld['yloc'], weld["yaw"], weld['cls'])

    ax.set_xlabel("x [mm]")
    ax.set_ylabel("y [mm]")

    legend_ = ax.legend(*im.legend_elements(), bbox_to_anchor=(1.1, 1), loc="upper right")
    ax.add_artist(legend_)

    ax.text(0.5, -0.5, caption, style='italic',
            horizontalalignment='center', verticalalignment='top', transform=ax.transAxes)
    ax.set_aspect(1)

    if out_path:
        fig.savefig(out_path, dpi=150)
        # Close only this figure; cla()/clf() after close() would open a new, leaked figure.
        plt.close(fig)
    else:
        plt.show()

In [None]:
## SEGMENTATION INFERENCE
## Run Segmentation

## High Level Graph Generation Config Settings
# Each graph level: (graph_level, num_neighbors, radius); all levels share the same
# generation method and graph_scale.
graph_gen_kwargs = {
    'add_rnd3d': True,
    'base_voxel_size': 0.8,
    'downsample_method': 'random',
    'level_configs': [
        {
            'graph_gen_kwargs': {'num_neighbors': n_neighbors, 'radius': radius},
            'graph_gen_method': 'disjointed_rnn_local_graph_v3',
            'graph_level': level,
            'graph_scale': 1,
        }
        for level, n_neighbors, radius in ((0, 64, 0.4), (1, 192, 1.2))
    ],
}


def configure_model(model_params_path, max_cls_classes=3, max_inst_classes=7, verbose=False, map_location=None):
    """
    Build the segmentation model and load its trained weights.

    Args:
        model_params_path: path to a state_dict file loadable with torch.load.
        max_cls_classes: passed to the model constructor as num_classes.
        max_inst_classes: passed to the model constructor as max_instance_no.
        verbose: print the setup time (and the missing-path message on failure).
        map_location: forwarded to torch.load, e.g. "cpu" to load weights saved on a
            GPU on a CPU-only machine. None keeps torch.load's default behaviour.

    Returns:
        the model with the loaded state_dict.

    Raises:
        FileNotFoundError: if model_params_path is not an existing file.
    """
    start = time.time()
    # Validate the path first so we don't construct the model only to throw it away.
    if not os.path.isfile(model_params_path):
        if verbose:
            print(f"[ModelParamPathError] {model_params_path} does not exist")
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), model_params_path)

    model = model1(num_classes=max_cls_classes, max_instance_no=max_inst_classes)
    model.load_state_dict(torch.load(model_params_path, map_location=map_location))

    if verbose:
        print(f"Model Setup Time (secs) : {time.time() - start}")
    return model

def classify_scan(model, scan_path, verbose=False):
    """
    Run one scan through the segmentation model.

    The cloud is loaded with pcloader, turned into a multi-level local graph with
    gen_multi_level_local_graph_v3 (levels from graph_gen_kwargs['level_configs']),
    and fed to `model`, which returns (cls_seg, inst_seg). The per-vertex prediction
    is the argmax over dim 1 of each.

    Returns
     - the x|y|z|cls|inst for the vertices of the last graph level as a K x 5 array
     - and also as a dict with keys 'vertices', 'cls_preds', 'inst_preds'
    """
    a = time.time()
    pointxyz, _offset = pcloader(scan_path)
    vertex_coord_list, keypoint_indices_list, edges_list = \
        gen_multi_level_local_graph_v3(pointxyz, 0.6, graph_gen_kwargs['level_configs'])
    # Keep the original-precision coordinates of the last level for the output.
    last_layer_v = vertex_coord_list[-1]

    ## conversions: type precision, then numpy array -> tensor
    vertex_coord_list = [torch.from_numpy(p.astype(np.float32)) for p in vertex_coord_list]
    keypoint_indices_list = [torch.from_numpy(k.astype(np.int32)).long() for k in keypoint_indices_list]
    edges_list = [torch.from_numpy(e.astype(np.int32)).long() for e in edges_list]

    ## Run graph through GNN model. Inference only, so don't build an autograd graph.
    # NOTE(review): the model's train/eval mode is not set here; if it contains
    # dropout or batch-norm, call model.eval() before inference — TODO confirm.
    with torch.no_grad():
        cls_seg, inst_seg = model(vertex_coord_list, keypoint_indices_list, edges_list)

    ## Most probable class/instance per vertex, as K x 1 numpy columns
    # (.cpu() so this also works when the model runs on a GPU)
    cls_preds = torch.argmax(cls_seg, dim=1).cpu().numpy()[:, None]
    inst_preds = torch.argmax(inst_seg, dim=1).cpu().numpy()[:, None]

    b = time.time()
    if verbose:
        print("Scan Inference Time (secs): ", b-a)
        print()

    return np.hstack((last_layer_v, cls_preds, inst_preds)),\
            {'vertices': last_layer_v, 'cls_preds': cls_preds, 'inst_preds': inst_preds}

def filter_out_background(scan_data):
    """
    Drop rows whose cls prediction (column 3) or inst prediction (column 4) is 0.

    Returns the kept rows together with the boolean row mask used to select them.
    """
    keep_mask = (scan_data[:, 3] != 0) & (scan_data[:, 4] != 0)
    return scan_data[keep_mask], keep_mask

# def count_cluster_by_instance_prediction(scan_data, threshold_factor=0.5):
#     """
#     Returns a cluster_count ie the number of clusters on the cls field/col that has at least a threshold number of members
#     Also: a dict of intermediary/final counts, types and thresholds
#     """
#     inst, inst_count = np.unique(scan_data[:,4], return_counts=True)
#     inst_count_threshold = threshold_factor * np.mean(inst_count)
#     reduced_idx = np.where(inst_count > inst_count_threshold)
#     cluster_count = np.sum(inst_count > inst_count_threshold)
#     reduced = inst[reduced_idx]
    
#     return cluster_count, {'orig_insts': inst,
#                            'orig_inst_ct': inst_count,
#                            'inst_ct_thresh': inst_count_threshold,
#                            'reduced_insts': reduced}
# def draw_labels(welds, ax):
#     for weld in welds:
#         markers = add_staple_patch(ax ,weld['xloc'], weld['yloc'], weld["yaw"], weld['cls'] )
        
# def _plot(path, points, values, title="", caption="", overlay_points=None, labels_dict=None ):
#     """
#     path - output path
#     points - matrix of x, y, etc cols
#     values - assigned value per point ie predicted/truth cls/inst
#     [optional] overlay_points = [{"xs":[], "ys":[], "cs":_, "marker":_, "label"=_}, ...]
#     labels_dict - to overlay the staple patch
#     """
    
#     fig = plt.figure()
#     ax = fig.add_subplot(111)
#     ax.set_title(title)
#     im = ax.scatter(points[:,0], points[:,1],s=0.25,c=values)
    

#     if overlay_points:    
#         for pts in overlay_points:
#             _im = ax.scatter(pts["xs"], pts["ys"], s=pts["ss"], c=pts["cs"], marker=pts["marker"])
    
#     if labels_dict:
#         draw_labels(labels_dict["welds"], ax)
            
#     ax.set_xlabel("x [mm]")
#     ax.set_ylabel("y [mm]")

        
#     legend_ = ax.legend(*im.legend_elements(), bbox_to_anchor=(1.1, 1), loc="upper right")
#     ax.add_artist(legend_)

#     ax.text(0.5, -0.5, caption, style='italic', \
#         horizontalalignment='center', verticalalignment='top', transform=ax.transAxes)
#     axes=plt.gca()
#     axes.set_aspect(1)
#     if path:
#         plt.savefig(path, dpi=150)
#         plt.close()
#         plt.cla()
#         plt.clf()
#     else:
#         plt.show
    
# def save_prediction_plots(non_bg_matrix, labels_dict=None, cls_tag=None, inst_tag=None, cls_col=3, inst_col=4, inst_seg_dir_path="./plots/inst/", cls_seg_dir_path="./plots/cls/", verbose=False):
#     """
#     non_bg_matrix -- x|y|z|cls|inst
#     // tag -- eg A_xycls_eps0_45_50
#     tag -- eg A_xycls
#     """
    
#     ## [RED FLAG] - what if dir_path comes through as None
#     f_tag = "[save_prediction_plots]"
#     f_msg = []
    
#     if not (cls_tag and inst_tag):
#         print("[ERROR] No specfied cls_tag or inst_tag args. Plots not generated")
#     else:
#         if cls_tag:
#             cls_output_path = cls_seg_dir_path+cls_tag+".png" if cls_seg_dir_path else None
#             _plot(cls_output_path, non_bg_matrix[:, :2], non_bg_matrix[:, cls_col], title=cls_tag, labels_dict=labels_dict)
#             f_msg.append(cls_output_path)

#         if inst_tag:
#             inst_output_path = inst_seg_dir_path+inst_tag+".png" if inst_seg_dir_path else None
#             _plot(inst_output_path, non_bg_matrix[:, :2], non_bg_matrix[:, inst_col], title=inst_tag, labels_dict=labels_dict)
#             f_msg.append(inst_output_path)

#         if verbose:
#             print(f"{f_tag}: Done --> {f_msg}")
            


# ###################################################
# ###  EVENTUALLY BUT SKIPPING THIS FOR NOW #########
# ## Need to record the model weights version
# ## Need to record the number of instances identified, the instance_seg_loss and cls_loss
# def save_prediction_stats(model, scan_path, output_file):
#     """
#     Intialize text file if non-existent with name | predictions_clusters | prediction cluster counts | cluster_count_threshold | reduced_clusters 
#     """
#     pass

# ###################################################

def run_preclustering(model, scan_path, labels_dict=None, sample_tag="", inst_seg_img_dir_path="./plots/inst_seg/", cls_seg_img_dir_path="./plots/cls_seg/"):
    """
    Segment one scan with `model` and keep only its non-background points.

    The scan is classified with classify_scan (verbose timing on) and the result is
    passed through filter_out_background. The plotting-related arguments
    (labels_dict, sample_tag, inst_seg_img_dir_path, cls_seg_img_dir_path) are
    accepted but currently unused: plotting is omitted here.

    Returns the N x 5 matrix x|y|z|cls|inst of the points that survive the filter.
    """
    keypoints, _as_dict = classify_scan(model, scan_path, verbose=True)  # K x 5
    foreground, _mask = filter_out_background(keypoints)                 # N x 5, N <= K
    return foreground

In [None]:
## CLUSTERING INSTANCE SEGMENTATION RESULTS INTO INDIVIDUAL WELDS

## dbscan params
## Notebook-level defaults, matching full_isolate_and_cluster's own defaults.
## DBSCAN requires eps > 0 and min_samples >= 1, so 0 is not a usable value for either.
eps = 2
min_samples = 20

## 
def partial_isolate_and_cluster(scan_data, inst_seg_col, inst_seg_value, dbscan_model, view_per_segment=False):
    """
    Isolate one predicted instance and keep only its largest DBSCAN cluster.

    Args:
        scan_data: matrix whose first two columns (x, y) are clustered.
        inst_seg_col: column holding the predicted instance id.
        inst_seg_value: instance id to isolate.
        dbscan_model: object with a scikit-learn style fit_predict (e.g. a DBSCAN
            instance). Label -1 is treated as noise, which is how DBSCAN marks it.
        view_per_segment: plot the raw and the filtered clustering.

    Returns:
        the rows of this instance that belong to its most populated non-noise cluster
        (ties go to the smallest label). The result has 0 rows if the instance has no
        rows or if every point is labelled noise.
    """
    partial = scan_data[scan_data[:, inst_seg_col] == inst_seg_value]
    if len(partial) == 0:
        return partial

    DBSCAN_result = np.asarray(dbscan_model.fit_predict(partial[:, :2]))

    ## Visualizing raw dbscan result
    if view_per_segment:
        _plot(partial, DBSCAN_result, caption=f"Raw DBSCAN clusters for inst_segment={inst_seg_value}")

    ## Clean up: keep only the largest real cluster. The noise label (-1) is not a
    ## cluster and must never be selected as the "largest" one.
    clustered_labels = DBSCAN_result[DBSCAN_result != -1]
    if clustered_labels.size == 0:
        return partial[:0]
    cluster_vals, cluster_counts = np.unique(clustered_labels, return_counts=True)
    val = cluster_vals[np.argmax(cluster_counts)]

    keep = DBSCAN_result == val
    partial = partial[keep]

    if view_per_segment:
        _plot(partial, DBSCAN_result[keep], caption=f"Filtered DBSCAN clusters for inst_segment={inst_seg_value}")

    return partial
    
    
    
def full_isolate_and_cluster(scan_data, tag, inst_seg_col=4, eps=2, min_samples=20, view_per_segment=False, view_full=False, out_dir=None):
    """
    Clean every predicted instance with DBSCAN and stack the survivors back together.

    For each (eps, min_samples) combination of the given values, every distinct value
    in column `inst_seg_col` is isolated and reduced to its largest DBSCAN cluster by
    partial_isolate_and_cluster. The kept rows are then re-aggregated.

    Args:
        scan_data: matrix x|y|z|cls|inst with the instance id in `inst_seg_col`.
        tag: filename prefix for saved plots.
        inst_seg_col: column holding the instance id.
        eps, min_samples: a single value or a tuple/list of values; every eps is
            paired with every min_samples.
        view_per_segment: plot raw/filtered DBSCAN output for each instance.
        view_full: plot the pre-DBSCAN data once and the filtered result per combination.
        out_dir: with view_full, save plots into this directory (created if missing)
            instead of showing them.

    Returns:
        the stacked filtered rows for the LAST (eps, min_samples) combination. The
        result has 0 rows if scan_data has no rows.

    Raises:
        ValueError: if eps or min_samples is an empty tuple/list.
    """
    eps_values = eps if isinstance(eps, (tuple, list)) else (eps,)
    min_samples_values = min_samples if isinstance(min_samples, (tuple, list)) else (min_samples,)

    ## every (eps, min_samples) permutation, eps-major
    combos = [(e, m) for e in eps_values for m in min_samples_values]
    if not combos:
        raise ValueError("eps and min_samples must each provide at least one value")
    print(combos)

    ## Instance ids do not depend on the DBSCAN parameters; compute them once
    insts = np.unique(scan_data[:, inst_seg_col])

    if view_full:
        # The pre-DBSCAN view is identical for every combination, so draw it once.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
            _plot(scan_data, scan_data[:, inst_seg_col], caption="PreDBSCAN",
                  out_path=os.path.join(out_dir, tag + "_rawSeg.png"))
        else:
            _plot(scan_data, scan_data[:, inst_seg_col], caption="PreDBSCAN")

    full = scan_data[:0]
    for eps_i, min_samples_i in combos:
        ## Construct the DBSCAN model for this combination
        DBSCAN_model = DBSCAN(eps=eps_i, min_samples=min_samples_i)

        ## Loop through all instances
        partials = []
        for inst_seg_value in insts:
            print("inst_seg_value", inst_seg_value)
            partials.append(partial_isolate_and_cluster(scan_data, inst_seg_col, inst_seg_value, DBSCAN_model, view_per_segment=view_per_segment))

        ## Aggregating all partials (np.vstack rejects an empty list)
        full = np.vstack(partials) if partials else scan_data[:0]

        if view_full:
            caption = f"Filtered FULL DBSCAN Result_eps{eps_i}_min{min_samples_i}"
            if out_dir:
                _plot(full, full[:, inst_seg_col], caption=caption,
                      out_path=os.path.join(out_dir, tag + f"_eps{eps_i}_minsamples{min_samples_i}.png"))
            else:
                _plot(full, full[:, inst_seg_col], caption=caption)

    return full


## RECOMBINING INTO A POST-CLUSTERING AGGREGATE

In [None]:
## Testing New DBSCAN sequence

##########
## TestSet
##########

# Trained weights and the dataset split to run on
model_params_path = "./_model/train1-fix3/2023_06_22_17_01_13/params_epoch497_for_min_test_loss.pt"
dataset_txt = "./Recreating Dataset/_data/test.txt"

# Plot destinations: img_dir receives the DBSCAN result plots; trash_dir is handed to
# run_preclustering's plot-directory arguments
img_dir = "./Recreating Dataset/_img_dbscan_on_insts/test/"
trash_dir = "./Recreating Dataset/_img_dbscan_on_insts/trash/"

# DBSCAN parameter sweep (each eps is paired with each min_samples)
eps = tuple(range(2, 6))                # (2, 3, 4, 5)
min_samples = tuple(range(20, 60, 10))  # (20, 30, 40, 50)

# ##########
# ## TrainSet
# ##########
# dataset_txt = "./Recreating Dataset/_data/train.txt"
# img_dir = "./Recreating Dataset/_img_dbscan_on_insts/train/"


def test_DBSCAN_onInstSegResults(model_params_path, dataset_txt, out_dir, trash_dir, eps=2, min_samples=20, max_samples=2):
    """
    Run segmentation plus per-instance DBSCAN clean-up over (part of) a dataset split.

    Args:
        model_params_path: trained weights file (see configure_model).
        dataset_txt: index file listing scan/label paths (see parse_dataset_file).
        out_dir: directory that receives the DBSCAN result plots.
        trash_dir: passed to run_preclustering as both plot directories.
        eps, min_samples: DBSCAN value(s); tuples/lists are swept as combinations.
        max_samples: process only the first `max_samples` samples of the split
            (default 2). None processes the whole split.

    Returns:
        dict mapping sample name -> filtered x|y|z|cls|inst matrix for the last
        (eps, min_samples) combination.
    """
    ## Instantiating Model
    seg_model = configure_model(model_params_path, verbose=True)

    ## Looping through dataset
    samples = parse_dataset_file(dataset_txt)
    if max_samples is not None:
        samples = samples[:max_samples]

    results = {}
    for sample in samples:
        scan_path = sample["scan_path"]
        sample_name = sample["name"]
        print(scan_path, sample_name)

        ## Segmentation
        scan_data = run_preclustering(seg_model, scan_path, sample_tag=sample_name, inst_seg_img_dir_path=trash_dir, cls_seg_img_dir_path=trash_dir)

        ## DBSCAN on Instance Seg
        results[sample_name] = full_isolate_and_cluster(scan_data, sample_name, 4, eps=eps, min_samples=min_samples, view_per_segment=True, view_full=True, out_dir=out_dir)
    return results

# Assigned (not left as the cell's last expression) so Jupyter doesn't dump the arrays.
dbscan_results = test_DBSCAN_onInstSegResults(model_params_path, dataset_txt, img_dir, trash_dir, eps=eps, min_samples=min_samples)