In [1]:
from glob import glob
from src.systems.cartpole import CartPoleSystem
from src.flow_matching.cartpole.latent_conditional.flow_matcher import CartPoleLatentConditionalFlowMatcher
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
import os
import torch

data_dir = "/common/users/shared/pracsys/genMoPlan/data_trajectories/cartpole_pybullet"
roa_file = "/common/users/shared/pracsys/genMoPlan/data_trajectories/cartpole_pybullet/roa_labels.txt"
bounds_file = "/common/users/dm1487/arcmg_datasets/cartpole_pybullet/cartpole_pybullet_data_bounds.pkl"


system = CartPoleSystem(bounds_file=bounds_file, use_dynamic_bounds=True)

ckpt_path = "/common/home/dm1487/robotics_research/tripods/olympics-classifier/outputs/cartpole_old_100_manifold/2025-11-11_03-12-53"
flow_matcher = CartPoleLatentConditionalFlowMatcher.load_from_checkpoint(ckpt_path, device="cuda:0")

Loading CartPole bounds from: /common/users/dm1487/arcmg_datasets/cartpole_pybullet/cartpole_pybullet_data_bounds.pkl
Use dynamic bounds: True
Path exists: True
  [0] Cart position (x): [-5.648, 5.843] -> limit: ¬±5.843
  [1] Pole angle (Œ∏): [-3.142, 3.142] -> WRAPPED to ¬±œÄ
  [2] Cart velocity (·∫ã): [-6.859, 7.039] -> limit: ¬±7.039
  [3] Angular velocity (Œ∏Ãá): [-8.039, 7.946] -> limit: ¬±8.039
Loaded CartPole bounds from: /common/users/dm1487/arcmg_datasets/cartpole_pybullet/cartpole_pybullet_data_bounds.pkl
üìÅ Folder provided: /common/home/dm1487/robotics_research/tripods/olympics-classifier/outputs/cartpole_old_100_manifold/2025-11-11_03-12-53
üîç Searching for checkpoint in folder...
   ‚úì Found best checkpoint (val_loss=0.0937)
   üìÑ Using: epoch257-val_loss0.0937.ckpt
ü§ñ Loading CartPole LCFM checkpoint: /common/home/dm1487/robotics_research/tripods/olympics-classifier/outputs/cartpole_old_100_manifold/2025-11-11_03-12-53/version_0/checkpoints/epoch257-val_loss0.093

In [2]:
roa_data = np.loadtxt(roa_file, delimiter=",")
inp, labels = roa_data[:, :-1], roa_data[:, -1]
inp = torch.from_numpy(inp).float().to("cuda:0")
labels = torch.from_numpy(labels).long().to("cuda:0")
np.mean(roa_data[:, -1] == 1)


0.17968548373221382

In [3]:
from tqdm import tqdm

samples = 100
repeats = 1
batch_size = 2048  # You can tune this depending on memory
success_threshold = 0.6
failure_threshold = 0.4
tp = 0
tn = 0
fp = 0
fn = 0

sep_count = 0

start_idx = 0
stop_idx = len(roa_data)

is_success = np.zeros((len(roa_data), samples, repeats))
for batch_start in tqdm(range(start_idx, stop_idx, batch_size)):
    batch_end = min(batch_start + batch_size, stop_idx)
    batch_inp = inp[batch_start:batch_end, :]

    # Will be (batch_size, samples, repeats)
    for sample_idx in range(samples):
        model_input = batch_inp.clone()
        for repeat_idx in range(repeats):
            pred = flow_matcher.predict_endpoint(model_input)
            # pred shape: (batch_size, d)
            is_success[batch_start:batch_end, sample_idx, repeat_idx] = system.classify_attractor(pred, 0.2).cpu().numpy()
            # is_success[batch_start:batch_end, sample_idx, repeat_idx] = pred[:, 21].cpu().numpy() > 1.3
            model_input = pred.clone()
            

# is_success_mean = is_success.mean(axis=(1,2))
# pred_success = (is_success_mean > success_threshold)
# pred_failure = (is_success_mean < failure_threshold) 

# # Compute tp, tn, fp, fn
# batch_labels = labels[start_idx:stop_idx].cpu().numpy()
# tp = np.sum((batch_labels == 1) & pred_success)
# fp = np.sum((batch_labels == 0) & pred_success)
# fn = np.sum((batch_labels == 1) & pred_failure)
# tn = np.sum((batch_labels == 0) & pred_failure)
# sep_count = np.sum((is_success_mean <= success_threshold) & (is_success_mean >= failure_threshold))
#     # # Success logic: check along [samples, repeats] for each data point in batch
#     # # We'll take the mean across all samples and repeats for head height > 1.3
    
#     # mean_success = (all_head_heights > 1.3).mean(axis=(1,2))
#     # is_success = mean_success > success_threshold  # shape: (batch_size,)
    
#     # is_failure = mean_success < failure_threshold  # shape: (batch_size,)
    
#     # sep_count += np.sum((mean_success <= success_threshold) & (mean_success >= failure_threshold))

#     # batch_labels = labels[batch_start:batch_end].cpu().numpy()  # shape: (batch_size,)

#     # tp += np.sum((batch_labels == 1) & (is_success))
#     # fp += np.sum((batch_labels == 0) & (is_success))
#     # fn += np.sum((batch_labels == 1) & (is_failure))  
#     # tn += np.sum((batch_labels == 0) & (is_failure))

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 57/57 [09:51<00:00, 10.39s/it]


In [4]:
batch_labels = labels[start_idx:stop_idx].cpu().numpy()
pred_labels = np.ones_like(batch_labels) * -1

failure = (is_success == -1).sum(axis=(1, 2))/samples > 0.6
success = (is_success == 1).sum(axis=(1, 2))/samples > 0.6

pred_labels[failure] = 0
pred_labels[success] = 1

In [5]:
tp = np.sum((batch_labels == 1) & (pred_labels == 1))
tn = np.sum((batch_labels == 0) & (pred_labels == 0))
fp = np.sum((batch_labels == 0) & (pred_labels == 1))
fn = np.sum((batch_labels == 1) & (pred_labels == 0))


In [6]:
precision = tp / (tp + fp) if tp + fp > 0 else 0
recall = tp / (tp + fn) if tp + fn > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
specificity = tn / (tn + fp) if tn + fp > 0 else 0
sep_perc = sep_count/len(roa_data)
print(f"Precision: {precision}, Recall: {recall}, F1: {f1}, Specificity: {specificity}, Sep: {sep_perc}")

# confusion matrix
conf_mat = np.zeros((2, 2))
conf_mat[0, 0] = tp
conf_mat[0, 1] = fp
conf_mat[1, 0] = fn
conf_mat[1, 1] = tn


Precision: 0, Recall: 0, F1: 0, Specificity: 1.0, Sep: 0.0


In [7]:
np.sum(pred_labels == -1) / len(roa_data)

0.9676020715404071