In [None]:
import os
import torch
import numpy as np
import sys
import open3d as o3d
import matplotlib.pyplot as plt

# Resolve the project root as the parent of the kernel's working directory.
# NOTE(review): assumes the kernel was started from the notebook's own folder,
# one level below the project root; os.getcwd() is the launch dir, not
# necessarily where the .ipynb file lives.
notebook_dir = os.getcwd()
parent_dir = os.path.dirname(notebook_dir)

# Make the project root importable so `src.*` resolves. Guarded so that
# re-running this cell does not keep appending duplicate sys.path entries.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from src.models.pointnetplusplus import PointNetPlusPlus
from src.configs.config import Config

# Path to the YAML config *file* (kept as `config_dir` for compatibility).
config_dir = os.path.join(parent_dir, 'src', 'configs', 'default_config.yaml')
config = Config(config_dir)
def load_trained_model(checkpoint_path, num_classes=2, device="cpu"):
    """Build a PointNetPlusPlus, load saved weights into it, and return it in eval mode.

    Args:
        checkpoint_path: path to a file produced by ``torch.save`` that holds the
            model's state dict.
        num_classes: number of output classes the network is built with.
        device: torch device (e.g. "cpu" or "cuda") for the model and the
            loaded tensors.

    Returns:
        The PointNetPlusPlus module on ``device``, switched to eval mode.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    network = PointNetPlusPlus(num_classes=num_classes)
    network = network.to(device)

    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    weights = torch.load(checkpoint_path, map_location=device)
    network.load_state_dict(weights)

    network.eval()
    return network

def infer_on_new_data(model, points, device="cpu"):
    """Predict one class label per point for a single, already-prepared cloud.

    The cloud is passed to the model as a batch of one. Any sampling or padding
    to the size the model expects must already have been done by the caller.

    Args:
        model: segmentation network, called as ``model(x)`` with ``x`` of shape
            (1, num_points, 3). Its output is assumed to be
            (1, num_points, num_classes); argmax is taken over the last axis.
        points: array-like of shape (num_points, 3).
        device: torch device on which the input tensor is created.

    Returns:
        numpy array holding the argmax class index per point (shape (num_points,)
        under the output-layout assumption above).
    """
    batch = torch.tensor(points, dtype=torch.float32, device=device)[None]
    model.eval()
    with torch.no_grad():
        logits = model(batch)
        labels = torch.argmax(logits, dim=-1)
        labels = labels.squeeze(0)
    return labels.cpu().numpy()

def plot_combined_original_and_seg(
    orig_points, orig_colors,
    seg_points, seg_labels,
    out_file="combined_visual.png"
):
    """Draw a cloud in its original colours next to its segmentation, save, and show.

    Left subplot ("Downsampled Original"): ``orig_points`` coloured by
    ``orig_colors``. Right subplot: ``seg_points`` with label 1 (plant) in
    green and label 0 (non-plant) in red; points with any other label are not
    drawn. The figure is written to ``out_file`` and then shown.

    Args:
        orig_points: (N, 3) array-like of xyz coordinates.
        orig_colors: (N, 3) RGB colours in [0, 1]. If None, or if its length does
            not match ``orig_points`` (e.g. the cloud had no colour data),
            the points are drawn in grey instead.
        seg_points: (M, 3) array-like of xyz coordinates used for segmentation.
        seg_labels: (M,) array-like of labels, 0 = non-plant, 1 = plant.
        out_file: path the figure is saved to.
    """
    orig_points = np.asarray(orig_points)
    seg_points = np.asarray(seg_points)
    # asarray makes `seg_labels == 1` an element-wise mask; on a plain Python
    # list the comparison would collapse to a single bool.
    seg_labels = np.asarray(seg_labels)

    # A cloud read without colour yields an empty (0, 3) colour array, which
    # scatter() rejects when there are N > 0 points; fall back to one colour.
    if orig_colors is None or len(orig_colors) != len(orig_points):
        orig_colors = "gray"

    fig = plt.figure(figsize=(14, 6))
    ax1 = fig.add_subplot(1, 2, 1, projection='3d')
    ax2 = fig.add_subplot(1, 2, 2, projection='3d')

    # Subplot #1: original colour
    ax1.scatter(orig_points[:, 0],
                orig_points[:, 1],
                orig_points[:, 2],
                c=orig_colors,
                s=2)
    ax1.set_title("Downsampled Original")
    ax1.set_xlabel("X")
    ax1.set_ylabel("Y")
    ax1.set_zlabel("Z")

    # Subplot #2: segmentation (non-plant drawn first so plant points sit on top)
    plant_mask = (seg_labels == 1)
    nonplant_mask = (seg_labels == 0)
    ax2.scatter(seg_points[nonplant_mask, 0],
                seg_points[nonplant_mask, 1],
                seg_points[nonplant_mask, 2],
                c='red', s=2, label='Non-Plant')
    ax2.scatter(seg_points[plant_mask, 0],
                seg_points[plant_mask, 1],
                seg_points[plant_mask, 2],
                c='green', s=2, label='Plant')
    ax2.set_title("Segmentation (Plant/Non-Plant)")
    ax2.set_xlabel("X")
    ax2.set_ylabel("Y")
    ax2.set_zlabel("Z")
    ax2.legend()

    fig.tight_layout()
    fig.savefig(out_file)
    plt.show()
    print(f"[INFO] Combined figure saved to {out_file}")


In [None]:
def visualize_points(points, colors, title=None, out_file='full_wheat_point_cloud.png'):
    """Scatter-plot a point cloud in 3-D with per-point colours, save it, then show it.

    Args:
        points: (N, 3) array-like of xyz coordinates.
        colors: per-point colours (the caller passes Open3D colours, (N, 3)
            in [0, 1]). If None, or if its length does not match ``points``
            (e.g. the cloud had no colour data), the points are drawn in grey.
        title: optional axes title; no title is set when None.
        out_file: path the figure is saved to (defaults to the previously
            hard-coded file name).
    """
    # The '3d' projection is registered by matplotlib itself (>= 3.2), so the
    # former unused `from mpl_toolkits.mplot3d import Axes3D` is not needed.
    points = np.asarray(points)
    if colors is None or len(colors) != len(points):
        colors = "gray"

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(points[:, 0],
               points[:, 1],
               points[:, 2],
               c=colors, s=1, alpha=0.5)

    if title is not None:
        ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")

    fig.savefig(out_file)
    plt.show()


In [None]:
# Use the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint location: <project root>/<config.model.save_dir>/best_model_train.pth
model_save_path = os.path.join(
    parent_dir,
    config.model.save_dir,
    'best_model_train.pth',
)

# 1) Load the trained 2-class segmentation model (returned in eval mode).
model = load_trained_model(model_save_path, 2, device)

# 2) Load a .ply cloud (with colour if the file has it).
file_name = os.path.join(parent_dir, 'data', 'wheat_data', 'raw', 'Wheat_Alsen_F8_2023-06-30-2013_fused_output.ply')
pcd = o3d.io.read_point_cloud(file_name)
points = np.asarray(pcd.points)   # shape (N,3)
# read_point_cloud only warns on a missing/unreadable file and returns an
# empty cloud, so fail loudly here instead of much later.
if len(points) == 0:
    raise ValueError(f"No points loaded from {file_name}")

if pcd.has_colors():
    colors = np.asarray(pcd.colors)   # shape (N,3) in [0,1]
else:
    # No RGB in the file: use uniform grey so the indexing and plots below work.
    colors = np.full_like(points, 0.5)
visualize_points(points, colors)

# Random subsample without replacement down to NUM_POINTS (when N is larger),
# seeded so that re-running the notebook picks the same points.
NUM_POINTS = 4096
SEED = 0
rng = np.random.default_rng(SEED)
N = len(points)
if N > NUM_POINTS:
    idx_sample = rng.choice(N, NUM_POINTS, replace=False)
    ds_points = points[idx_sample]
    ds_colors = colors[idx_sample]
else:
    # NOTE(review): smaller clouds are passed through unpadded; confirm the
    # model accepts fewer than NUM_POINTS points.
    ds_points = points
    ds_colors = colors

# 3) Run inference on the same down-sampled points, so the predicted labels line
#    up index-for-index with ds_points / ds_colors.
#    (A chunked approach over the full cloud would be an alternative.)
predicted_labels = infer_on_new_data(model, ds_points, device=device)
print("Predicted labels shape:", predicted_labels.shape)
# Fail early if the model's output layout does not give exactly one label per
# point (infer_on_new_data assumes the output is (1, N, num_classes)).
assert predicted_labels.shape == (len(ds_points),), (
    f"expected one label per point ({len(ds_points)}), got shape {predicted_labels.shape}"
)

# 4) Side-by-side figure: left = original colours, right = segmentation.
plot_combined_original_and_seg(
    ds_points, ds_colors,         # "Left" subplot: original color
    ds_points, predicted_labels,  # "Right" subplot: segmentation
    out_file="combined_visual.png"
)

In [None]:
import open3d as o3d

def open3d_visualize_segmentation(points, labels, window_name="Plant vs. Non-Plant"):
    """Open an interactive Open3D window showing a cloud coloured by label.

    Args:
        points: (N, 3) array-like of xyz coordinates.
        labels: (N,) array-like of labels; 1 = plant (green), 0 = non-plant
            (red). Points with any other label are coloured black.
        window_name: title of the viewer window (defaults to the original title).
    """
    points = np.asarray(points, dtype=np.float64)
    # asarray makes `labels == 1` an element-wise mask even for a plain list.
    labels = np.asarray(labels)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    # Explicit float64 (N, 3) buffer. zeros_like(points) would inherit an
    # integer dtype if integer coordinates were passed in.
    colors = np.zeros((len(points), 3), dtype=np.float64)
    colors[labels == 1] = [0.0, 1.0, 0.0]   # plant -> green
    colors[labels == 0] = [1.0, 0.0, 0.0]   # non-plant -> red

    pcd.colors = o3d.utility.Vector3dVector(colors)
    # NOTE(review): draw_geometries opens a GUI window and needs a display.
    o3d.visualization.draw_geometries([pcd], window_name=window_name)

# Interactive Open3D view of the down-sampled cloud coloured by predicted label.
open3d_visualize_segmentation(points=ds_points, labels=predicted_labels)