In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
from model import SingleImageTransformerCLIP
from dataset import BarLinkageDataset 
import matplotlib.pyplot as plt
# Allow PyTorch to use faster, reduced-precision internal math for float32 matmuls
# ('medium' is the most permissive setting); outputs may differ slightly from full fp32.
torch.set_float32_matmul_precision('medium')

from curve_plot import get_pca_inclination, rotate_curve
import scipy.spatial.distance as sciDist
from tqdm import tqdm
import requests
import time
import matplotlib.pyplot as plt
import os
import json
import torch.nn.functional as F

# Prefer the second GPU (cuda:1) when it exists; fall back to cuda:0 on single-GPU
# machines (moving tensors to a non-existent cuda:1 fails) and to the CPU without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# --- Headless simulator connection ---
index = 0  # local server index
API_ENDPOINT = "http://localhost:4000/simulation"
HEADERS = {"Content-Type": "application/json"}

# --- Simulation request parameters ---
speedscale = 1
steps = 360
# Minimum number of returned poses for a simulation to be accepted (20 per 360 steps).
minsteps = int(steps * 20 / 360)

In [None]:
# Machine-specific locations of the trained weights and the processed dataset.
checkpoint_path = "weights/transformer_weights_17/d512_h8_n6_bs512_lr0.0001_best.pth"
data_dir = "/home/anurizada/Documents/processed_dataset_17"
batch_size = 1

dataset = BarLinkageDataset(data_dir=data_dir)
# Keep dataset order (no shuffling) so sample indices are stable between runs.
dataloader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=batch_size,
)

In [None]:
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location=device)
model_config = checkpoint["model_config"]

# Rebuild the architecture from the hyper-parameters saved with the weights.
# vocab_size is the stored value + 1; presumably this must mirror how the checkpoint's
# model was constructed (load_state_dict below would reject a mismatch) -- TODO confirm why.
model = SingleImageTransformerCLIP(
    **{key: model_config[key] for key in ("tgt_seq_len", "d_model", "h", "N", "num_labels")},
    vocab_size=model_config["vocab_size"] + 1,
).to(device)

# Restore the trained weights and switch to inference mode.
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

In [None]:
import os
import json
import time
import numpy as np
import torch
import torch.nn.functional as F
import requests
from tqdm import tqdm
import matplotlib.pyplot as plt

# If not already imported in your session:
from curve_plot import get_pca_inclination, rotate_curve

# ===================================
# CONFIGURATION
# ===================================
label_mapping_path = "/home/anurizada/Documents/processed_dataset_17/label_mapping.json"
coupler_mapping_path = "/home/anurizada/Documents/transformer/BSIdict.json"  # coupler info

# --- load mappings ---
# label_mapping.json supplies "index_to_label" (string id -> mech name) and "num_bins".
with open(label_mapping_path, "r") as f:
    label_mapping = json.load(f)
index_to_label = label_mapping["index_to_label"]                           # e.g. {"0": "RRRR", ...}
label_to_index = {name: int(idx) for idx, name in index_to_label.items()}  # inverse: {"RRRR": 0, ...}

# BSIdict.json: per-mech entries whose "c" list marks the coupler joint with a 1.
with open(coupler_mapping_path, "r") as f:
    coupler_mapping = json.load(f)

# Mechanism label ids swept during inference (inclusive range).
MECH_MIN = 0
MECH_MAX = 16
mechanism_types = [index_to_label[str(k)] for k in range(MECH_MIN, MECH_MAX + 1)]

# --- coordinate binning setup ---
class CoordinateBinner:
    """Uniform quantizer over [-kappa, kappa] that maps bin indices back to bin-center values."""

    def __init__(self, kappa=1.0, num_bins=200):
        self.kappa = kappa
        self.num_bins = num_bins
        # num_bins + 1 evenly spaced edges delimit num_bins bins; each bin is represented by its midpoint.
        edges = np.linspace(-kappa, kappa, num_bins + 1)
        self.bin_edges = edges
        self.bin_centers = (edges[1:] + edges[:-1]) / 2

    def bin_to_value_torch(self, bin_index_tensor):
        """Return the float32 bin-center value for every index, on the index tensor's device.

        Out-of-range indices are clamped into [0, num_bins - 1] rather than raising.
        """
        clamped = bin_index_tensor.clamp(0, self.num_bins - 1)
        centers = torch.as_tensor(self.bin_centers, dtype=torch.float32, device=clamped.device)
        return centers[clamped]

NUM_BINS = label_mapping["num_bins"]
# Vocab ids below BIN_OFFSET are special tokens (SOS, EOS, PAD); ids >= BIN_OFFSET are coordinate bins.
BIN_OFFSET = 3
binner = CoordinateBinner(kappa=1.0, num_bins=NUM_BINS)

# Report the mech-type range actually swept (derived from MECH_MIN/MECH_MAX, not hard-coded).
print(f"Loaded {len(index_to_label)} mechanism types (using {MECH_MIN}..{MECH_MAX}).")
print(f"Loaded coupler mapping with {len(coupler_mapping)} entries.")
print("Started")

start_time = time.time()
sos_token, eos_token, pad_token = 0, 1, 2
# Special tokens must sit below BIN_OFFSET, otherwise bins_to_continuous would decode them as coordinates.
assert max(sos_token, eos_token, pad_token) < BIN_OFFSET, "special tokens overlap coordinate bins"

# The simulator settings used below (API_ENDPOINT, HEADERS, speedscale, steps, minsteps)
# come from the headless-simulator config cell earlier in this notebook; run that cell first.

# ===================================
# HELPERS
# ===================================
def coupler_index_for(mech_type: str) -> int:
    """Look up the coupler joint index for ``mech_type`` in the BSIdict mapping (``coupler_mapping``).

    The index is the position of the first 1 in the entry's "c" list. Returns -1 when the
    mech type, its "c" field, or a 1 inside a list-valued "c" is missing. Callers that index
    poses with -1 then silently get the last joint.
    """
    if mech_type not in coupler_mapping:
        return -1
    entry = coupler_mapping[mech_type]
    if "c" not in entry:
        return -1
    flags = entry["c"]
    if not isinstance(flags, list) or 1 not in flags:
        return -1
    return flags.index(1)

def bins_to_continuous(seq, binner, bin_offset):
    """Decode a 1D sequence of vocab ids into continuous coordinates.

    Ids below ``bin_offset`` (special tokens) are dropped. The remaining ids are shifted by
    ``-bin_offset`` and mapped through ``binner.bin_to_value_torch``. Returns a NumPy array;
    it is an empty float32 array when the sequence holds no coordinate tokens.
    """
    tokens = np.array(seq, dtype=np.int64)
    bin_ids = tokens[tokens >= bin_offset] - bin_offset
    if not bin_ids.size:
        return np.array([], dtype=np.float32)
    values = binner.bin_to_value_torch(torch.tensor(bin_ids, dtype=torch.long))
    return values.cpu().numpy()

def predict_single_autoreg(model, image_tensor, mech_label_idx, max_seq_len, device,
                           top_k=None, temperature=1.0):
    """
    Autoregressively decode a token sequence for one image conditioned on one mechanism label.

    Decoding starts from ``sos_token`` and runs for at most ``max_seq_len`` steps, stopping
    right after ``eos_token`` is produced (the EOS id is included in the result).

    Args:
        model: called as ``model(decoder_input, causal_mask, image, label)``; must return a
            3-tuple whose first element is rank-3 logits (batch, time, vocab presumably).
        image_tensor: a single image without batch dim (a batch dim of 1 is added here).
        mech_label_idx: integer mechanism label id.
        max_seq_len: maximum number of tokens to generate.
        device: device the inputs are moved to.
        top_k: a positive int k samples from the k most probable tokens. Anything else
            (``None``, ``False``, ``True``, 0, negatives) means greedy decoding.
        temperature: softmax temperature for top-k sampling; must be > 0 when sampling.
            Ignored for greedy decoding.

    Returns:
        np.ndarray of int64 predicted token ids (SOS excluded).

    Raises:
        ValueError: if sampling is requested with ``temperature <= 0``.
    """
    # bool is excluded explicitly: isinstance(True, int) holds, and top_k=True used to mean k=1.
    use_sampling = top_k is not None and not isinstance(top_k, bool) and int(top_k) > 0
    if use_sampling and float(temperature) <= 0:
        raise ValueError(f"temperature must be > 0 for top-k sampling, got {temperature}")

    model.eval()
    with torch.no_grad():
        # Ensure inputs on correct device; keep original tensors unmodified
        image_1 = image_tensor.unsqueeze(0).to(device, non_blocking=True)
        label_1 = torch.tensor([int(mech_label_idx)], device=device, dtype=torch.long)

        # Start with SOS
        decoder_input = torch.full((1, 1), sos_token, device=device, dtype=torch.long)
        pred = []

        for _ in range(max_seq_len):
            T = decoder_input.size(1)
            # Boolean [T, T] mask, True where key position j <= query position i (lower triangle).
            causal_mask = (torch.triu(torch.ones(T, T, device=device)) == 1).T

            logits, _, _ = model(decoder_input, causal_mask, image_1, label_1)
            next_logits = logits[:, -1, :]

            if use_sampling:
                probs = F.softmax(next_logits / float(temperature), dim=-1)
                k = min(int(top_k), probs.size(-1))
                topk_probs, topk_idx = torch.topk(probs, k=k, dim=-1)
                sampled = torch.multinomial(topk_probs, num_samples=1)  # [1,1]
                next_token = topk_idx.gather(-1, sampled).squeeze(1)    # [1]
            else:
                # argmax is invariant to softmax/positive temperature, so use raw logits
                # (also avoids inf/NaN when temperature is 0).
                next_token = torch.argmax(next_logits, dim=-1)          # [1]

            token = int(next_token.item())
            pred.append(token)
            decoder_input = torch.cat([decoder_input, next_token.unsqueeze(1)], dim=1)

            if token == eos_token:
                break

    return np.array(pred, dtype=np.int64)

# ===================================
# ONE-BY-ONE INFERENCE OVER DATALOADER
# ===================================
device = next(model.parameters()).device  # use model's device (avoids cuda:0/cuda:1 mismatches)
max_samples = 10
REQUEST_TIMEOUT_S = 30     # seconds per simulator call; a stalled server must not hang the loop
MIN_CURVE_POINTS = 30      # predicted coupler curves with fewer points are skipped
PRINT_PRED_TOKENS = True   # echo raw predicted token ids for every (sample, mech type)

# Decode length: prefer tgt_seq_len from the model's own config, else from the checkpoint
# config loaded together with the weights (`model_config`); if neither provides it, each
# batch falls back to its padded target length.
_model_cfg = getattr(model, "model_config", None) or model_config
model_tgt_len = _model_cfg.get("tgt_seq_len", None)


def _simulate_coupler_curve(points, mech_type, coupler_idx):
    """Run one mechanism through the headless simulator and return its coupler trajectory.

    Args:
        points: joint coordinates as a list of [x, y] pairs (sent as "params").
        mech_type: mechanism type name (sent as "type").
        coupler_idx: joint index whose trajectory is returned (-1 selects the last joint).

    Returns:
        (x, y) float arrays, or None on any failure. Failures include: the request or JSON
        parse fails, "poses" is missing or null, fewer than `minsteps` poses come back, the
        coupler index is out of range, or the trajectory contains NaN/inf.
    """
    payload = [{
        "params": points,
        "type": mech_type,
        "speedScale": speedscale,
        "steps": steps,
        "relativeTolerance": 0.1,
    }]
    try:
        response = requests.post(url=API_ENDPOINT, headers=HEADERS,
                                 data=json.dumps(payload), timeout=REQUEST_TIMEOUT_S)
        result = response.json()
        poses = result[0].get("poses") if isinstance(result, list) and result else None
        # "poses" may be JSON null; np.array(None) would be 0-d and P.shape[0] would then raise.
        if poses is None:
            return None
        P = np.asarray(poses, dtype=float)
    except Exception:  # best-effort: transport/JSON/ragged-array errors just skip this curve
        return None
    if P.ndim < 3 or P.shape[2] < 2 or P.shape[0] < minsteps:
        return None
    if not -P.shape[1] <= coupler_idx < P.shape[1]:
        return None
    x, y = P[:, coupler_idx, 0], P[:, coupler_idx, 1]
    if not (np.isfinite(x).all() and np.isfinite(y).all()):
        return None
    return x, y


def _align_to_reference(x, y, ref_phi, ref_denom, ref_mean_x, ref_mean_y):
    """Rotate, scale and translate curve (x, y) onto a reference curve.

    Three steps are applied in order:
    - Rotate so the curve's negated PCA inclination matches `ref_phi`.
    - Scale so sqrt(var(x) + var(y)) (+1e-8) matches `ref_denom`.
    - Shift so the curve's mean matches (`ref_mean_x`, `ref_mean_y`).
    """
    phi = -get_pca_inclination(x, y)
    x, y = rotate_curve(x, y, phi - ref_phi)
    scale = ref_denom / (np.sqrt(np.var(x) + np.var(y)) + 1e-8)
    x = np.asarray(x, dtype=float) * scale
    y = np.asarray(y, dtype=float) * scale
    return x - (np.mean(x) - ref_mean_x), y - (np.mean(y) - ref_mean_y)


def _run_sample(sample_idx, image, gt_label_idx, tgt_seq, max_seq_len):
    """Evaluate one dataset sample against every mechanism type MECH_MIN..MECH_MAX.

    First the sample's ground-truth mechanism is simulated. Then, for each mech type, the
    model's prediction is simulated, aligned to the GT coupler curve and plotted to
    results/sample_XXX/<mech>.jpg. Returns early (no plots) when the GT sequence decodes to
    no joints or its simulation fails.
    """
    gt_mech_type = index_to_label[str(gt_label_idx)]
    out_dir = f"results/sample_{sample_idx:03d}"
    os.makedirs(out_dir, exist_ok=True)

    # GT joints: strip PAD, decode bins to coordinates (special tokens dropped), pair as [x, y].
    gt_cont = bins_to_continuous(tgt_seq[tgt_seq != pad_token], binner, BIN_OFFSET)
    if gt_cont.size % 2 == 1:
        gt_cont = gt_cont[:-1]
    if gt_cont.size == 0:
        return

    gt_curve = _simulate_coupler_curve(gt_cont.reshape(-1, 2).tolist(), gt_mech_type,
                                       coupler_index_for(gt_mech_type))
    if gt_curve is None:
        return
    original_x, original_y = gt_curve
    orig_phi = -get_pca_inclination(original_x, original_y)
    orig_denom = np.sqrt(np.var(original_x) + np.var(original_y)) + 1e-8
    ox_mean, oy_mean = np.mean(original_x), np.mean(original_y)

    for mech_idx in range(MECH_MIN, MECH_MAX + 1):
        mech_type = index_to_label[str(mech_idx)]
        pred_seq = predict_single_autoreg(
            model=model,
            image_tensor=image,            # moved to device inside
            mech_label_idx=mech_idx,       # integer label id
            max_seq_len=max_seq_len,
            device=device,
            top_k=None,                    # greedy; pass a positive int for top-k sampling
            temperature=1.0,
        )
        if PRINT_PRED_TOKENS:
            # tqdm.write keeps the progress bar intact, unlike print.
            tqdm.write(f"[sample {sample_idx} | {mech_type}] {pred_seq}")

        pred_cont = bins_to_continuous(pred_seq, binner, BIN_OFFSET)
        if pred_cont.size % 2 == 1:
            pred_cont = pred_cont[:-1]
        if pred_cont.size == 0:
            continue

        pred_curve = _simulate_coupler_curve(pred_cont.reshape(-1, 2).tolist(), mech_type,
                                             coupler_index_for(mech_type))
        if pred_curve is None or len(pred_curve[0]) < MIN_CURVE_POINTS:
            continue
        generated_x, generated_y = _align_to_reference(
            *pred_curve, orig_phi, orig_denom, ox_mean, oy_mean
        )

        fig, ax = plt.subplots()
        ax.plot(original_x, original_y, "r", label=f"GT ({gt_mech_type})")
        ax.plot(generated_x, generated_y, "g", label=f"Pred: {mech_type}")
        ax.axis("equal")
        ax.legend()
        ax.set_title(f"Sample {sample_idx} | GT={gt_mech_type} | Pred={mech_type}")
        fig.savefig(os.path.join(out_dir, f"{mech_type}.jpg"))
        plt.close(fig)  # many figures are created per sample; release each one


samples_processed = 0

for batch in tqdm(dataloader, desc="One-by-one inference"):
    if samples_processed >= max_samples:
        break

    images = batch["images"]                     # (B, C, H, W)
    labels_enc = batch["encoded_labels"]         # (B,) integer label ids
    targets_discrete = batch["labels_discrete"]  # (B, T)
    max_seq_len = int(model_tgt_len) if model_tgt_len is not None else int(targets_discrete.size(1))

    for bi in range(images.size(0)):
        if samples_processed >= max_samples:
            break
        _run_sample(
            sample_idx=samples_processed,
            image=images[bi].cpu(),                     # predict fn moves it to device
            gt_label_idx=int(labels_enc[bi].item()),
            tgt_seq=targets_discrete[bi].cpu().numpy(),
            max_seq_len=max_seq_len,
        )
        samples_processed += 1  # skipped samples count toward max_samples too

print(f"✅ Finished all samples in {time.time() - start_time:.2f} seconds")



In [None]:
# import os
# import json
# import time
# import numpy as np
# import torch
# import torch.nn.functional as F
# import requests
# from tqdm import tqdm
# import matplotlib.pyplot as plt

# # ===================================
# # CONFIGURATION
# # ===================================
# label_mapping_path = "/home/anurizada/Documents/processed_dataset_17/label_mapping.json"
# coupler_mapping_path = "/home/anurizada/Documents/transformer/BSIdict.json"

# with open(label_mapping_path, "r") as f:
#     label_mapping = json.load(f)
# index_to_label = label_mapping["index_to_label"]

# with open(coupler_mapping_path, "r") as f:
#     coupler_mapping = json.load(f)

# # --- coordinate binning setup ---
# class CoordinateBinner:
#     def __init__(self, kappa=1.0, num_bins=200):
#         self.kappa = kappa
#         self.num_bins = num_bins
#         self.bin_edges = np.linspace(-kappa, kappa, num_bins + 1)
#         self.bin_centers = (self.bin_edges[:-1] + self.bin_edges[1:]) / 2

#     def bin_to_value_torch(self, bin_index_tensor):
#         bin_index_tensor = torch.clamp(bin_index_tensor, 0, self.num_bins - 1)
#         bin_centers_tensor = torch.tensor(self.bin_centers, device=bin_index_tensor.device, dtype=torch.float32)
#         return bin_centers_tensor[bin_index_tensor]


# # --- setup ---
# NUM_BINS = label_mapping["num_bins"]
# BIN_OFFSET = 3
# binner = CoordinateBinner(kappa=1.0, num_bins=NUM_BINS)

# print(f"Loaded label mapping with {len(index_to_label)} mechanism types.")
# print(f"Loaded coupler mapping with {len(coupler_mapping)} entries.")
# print(f"Coordinate binning: {NUM_BINS} bins, BIN_OFFSET={BIN_OFFSET}")
# print("Started")
# start_time = time.time()

# # --- Tokens ---
# sos_token = 0
# eos_token = 1
# pad_token = 2


# # ===================================
# # SINGLE-SAMPLE INFERENCE
# # ===================================
# def predict_single(model, image, label, max_seq_len, device, top_k=10, temperature=1.0, use_top_k=False):
#     """
#     Autoregressive decoding for a single image and label (one at a time).
#     """
#     model.eval()
#     with torch.no_grad():
#         image = image.unsqueeze(0).to(device)
#         label = label.unsqueeze(0).to(device)
#         decoder_input = torch.full((1, 1), sos_token, device=device, dtype=torch.long)
#         pred_seq = []

#         for _ in range(max_seq_len):
#             seq_len = decoder_input.shape[1]
#             causal_mask = (torch.triu(torch.ones(seq_len, seq_len, device=device)) == 1).transpose(0, 1)
#             preds, _, _ = model(decoder_input, causal_mask, image, label)
#             next_logits = preds[:, -1, :] / temperature
#             probs = F.softmax(next_logits, dim=-1)

#             if use_top_k:
#                 topk_probs, topk_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]), dim=-1)
#                 next_token_rel = torch.multinomial(topk_probs, num_samples=1)
#                 next_token = topk_indices.gather(-1, next_token_rel).squeeze(1)
#             else:
#                 next_token = torch.argmax(probs, dim=-1)

#             token_id = int(next_token.item())
#             pred_seq.append(token_id)
#             decoder_input = torch.cat([decoder_input, next_token.unsqueeze(1)], dim=1)
#             if token_id == eos_token:
#                 break

#     return np.array(pred_seq)


# # ===================================
# # RUN INFERENCE ONE-BY-ONE
# # ===================================
# device = next(model.parameters()).device
# max_samples = 100
# predictions, targets, label_indices = [], [], []
# samples_processed = 0

# for batch in tqdm(dataloader, desc="Running one-by-one inference"):
#     if samples_processed >= max_samples:
#         break

#     images = batch["images"]
#     labels = batch["encoded_labels"]
#     targets_discrete = batch["labels_discrete"]

#     for i in range(images.shape[0]):
#         if samples_processed >= max_samples:
#             break

#         image = images[i]
#         label = labels[i]
#         target_seq = targets_discrete[i].cpu().numpy()
#         target_seq = target_seq[target_seq != pad_token]

#         label_idx = int(label.item()) if label.numel() == 1 else int(torch.argmax(label).item())
#         mech_type = index_to_label[str(label_idx)]

#         max_seq_len = model.model_config["tgt_seq_len"] if hasattr(model, "model_config") else len(target_seq)

#         pred_seq = predict_single(model, image, label, max_seq_len, device, top_k=10, temperature=1.0, use_top_k=False)

#         predictions.append(pred_seq)
#         targets.append(target_seq)
#         label_indices.append(label_idx)
#         samples_processed += 1

# print(f"✅ Processed {samples_processed} samples individually.")


# # ===================================
# # CONVERT BINS → CONTINUOUS COORDINATES
# # ===================================
# def bins_to_continuous(seq, binner, bin_offset):
#     seq = np.array(seq)
#     numeric_mask = seq >= bin_offset
#     seq_numeric = seq[numeric_mask] - bin_offset
#     if len(seq_numeric) == 0:
#         return np.array([])
#     seq_tensor = torch.tensor(seq_numeric, dtype=torch.long)
#     seq_cont = binner.bin_to_value_torch(seq_tensor).cpu().numpy()
#     return seq_cont


# # ===================================
# # SIMULATION LOOP
# # ===================================
# for idx, (pred_seq, target_seq, label_idx) in enumerate(zip(predictions, targets, label_indices)):
#     mech_type = index_to_label[str(label_idx)]

#     if mech_type in coupler_mapping and "c" in coupler_mapping[mech_type]:
#         cvec = coupler_mapping[mech_type]["c"]
#         couplerCurveIndex = cvec.index(1) if 1 in cvec else -1
#     else:
#         couplerCurveIndex = -1

#     pred_cont = bins_to_continuous(pred_seq, binner, BIN_OFFSET)
#     target_cont = bins_to_continuous(target_seq, binner, BIN_OFFSET)

#     if len(pred_cont) % 2 == 1:
#         pred_cont = pred_cont[:-1]
#     if len(target_cont) % 2 == 1:
#         target_cont = target_cont[:-1]

#     if len(pred_cont) == 0 or len(target_cont) == 0:
#         continue

#     pred_joints = pred_cont.reshape(-1, 2)
#     gt_joints = target_cont.reshape(-1, 2)
#     num_joints = gt_joints.shape[0]

#     j_points_gt = [gt_joints[i].tolist() for i in range(num_joints)]
#     j_points_pred = [pred_joints[i].tolist() for i in range(min(num_joints, pred_joints.shape[0]))]

#     # --- ORIGINAL MECHANISM SIMULATION ---
#     exampleData = {
#         "params": j_points_gt,
#         "type": mech_type,
#         "speedScale": speedscale,
#         "steps": steps,
#         "relativeTolerance": 0.1,
#     }

#     try:
#         temp = requests.post(url=API_ENDPOINT, headers=HEADERS, data=json.dumps([exampleData])).json()
#         time.sleep(0.05)
#     except ValueError:
#         continue

#     if temp[0]["poses"] is None:
#         continue

#     P = np.array(temp[0]["poses"])
#     if P.shape[0] < minsteps:
#         continue

#     original_x, original_y = P[:, couplerCurveIndex, 0], P[:, couplerCurveIndex, 1]
#     original_phi = -get_pca_inclination(original_x, original_y)
#     original_denom = np.sqrt(np.var(original_x) + np.var(original_y))
#     original_mean_x, original_mean_y = np.mean(original_x), np.mean(original_y)

#     # --- PREDICTED MECHANISM SIMULATION ---
#     exampleData["params"] = j_points_pred
#     try:
#         temp = requests.post(url=API_ENDPOINT, headers=HEADERS, data=json.dumps([exampleData])).json()
#         time.sleep(0.05)
#     except ValueError:
#         continue
#     if temp[0]["poses"] is None:
#         continue
#     P = np.array(temp[0]["poses"])
#     if P.shape[0] < minsteps:
#         continue

#     generated_x, generated_y = P[:, couplerCurveIndex, 0], P[:, couplerCurveIndex, 1]
#     if np.isnan(generated_x).any() or np.isinf(generated_x).any() or len(generated_x) < 30:
#         continue

#     generated_phi = -get_pca_inclination(generated_x, generated_y)
#     rotation = generated_phi - original_phi
#     generated_x, generated_y = rotate_curve(generated_x, generated_y, rotation)

#     generated_denom = np.sqrt(np.var(generated_x) + np.var(generated_y))
#     scale_factor = original_denom / (generated_denom + 1e-8)
#     generated_x *= scale_factor
#     generated_y *= scale_factor

#     generated_mean_x, generated_mean_y = np.mean(generated_x), np.mean(generated_y)
#     translation_x, translation_y = generated_mean_x - original_mean_x, generated_mean_y - original_mean_y
#     generated_x -= translation_x
#     generated_y -= translation_y

#     # --- PLOT BOTH CURVES ---
#     plt.plot(original_x, original_y, "r", label="original")
#     plt.plot(generated_x, generated_y, "g", label="predicted")
#     plt.title(f"Mechanism: {mech_type} | Coupler index: {couplerCurveIndex}")
#     plt.axis("equal")
#     plt.legend()

#     out_dir = f"results/{idx}"
#     os.makedirs(out_dir, exist_ok=True)
#     plt.savefig(f"{out_dir}/{idx}_{mech_type}_iter_pred.jpg")
#     plt.clf()

# print(f"✅ Finished all samples in {time.time() - start_time:.2f} seconds")


In [None]:
# import json
# import torch
# import numpy as np
# from tqdm import tqdm
# import matplotlib.pyplot as plt
# import requests
# import time
# import os
# from curve_plot import get_pca_inclination, rotate_curve

# # ===================================
# # CONFIGURATION
# # ===================================
# label_mapping_path = "/home/anurizada/Documents/processed_dataset_17/label_mapping.json"
# with open(label_mapping_path, "r") as f:
#     label_mapping = json.load(f)
# index_to_label = label_mapping["index_to_label"]

# # --- coordinate binning setup ---
# class CoordinateBinner:
#     def __init__(self, kappa=1.0, num_bins=200):
#         self.kappa = kappa
#         self.num_bins = num_bins
#         self.bin_edges = np.linspace(-kappa, kappa, num_bins + 1)
#         self.bin_centers = (self.bin_edges[:-1] + self.bin_edges[1:]) / 2

#     def bin_to_value_torch(self, bin_index_tensor):
#         bin_index_tensor = torch.clamp(bin_index_tensor, 0, self.num_bins - 1)
#         bin_centers_tensor = torch.tensor(self.bin_centers, device=bin_index_tensor.device, dtype=torch.float32)
#         return bin_centers_tensor[bin_index_tensor]

# # from your label_mapping.json
# NUM_BINS = label_mapping["num_bins"]
# BIN_OFFSET = 3  # usually 3
# binner = CoordinateBinner(kappa=1.0, num_bins=NUM_BINS)

# print(f"Loaded label mapping with {len(index_to_label)} mechanism types.")
# print(f"Coordinate binning: {NUM_BINS} bins, BIN_OFFSET={BIN_OFFSET}")

# print('Started')
# start_time = time.time()

# eos_token = 1
# pad_token = 2

# # ===================================
# # BATCH INFERENCE
# # ===================================
# def predict_batch(model, dataloader, max_samples=100, device="cuda"):
#     all_predictions, all_targets, all_labels = [], [], []
#     samples_processed = 0

#     with torch.no_grad():
#         for batch in tqdm(dataloader, desc="Running batch inference"):
#             if samples_processed >= max_samples:
#                 break

#             decoder_input = batch["decoder_input_discrete"].to(device)
#             decoder_mask = batch["causal_mask"].to(device)
#             images = batch["images"].to(device)
#             encoded_labels = batch["encoded_labels"].to(device)
#             target_tokens = batch["labels_discrete"].to(device)

#             predictions, _, _ = model(decoder_input, decoder_mask, images, encoded_labels)
#             pred_tokens = predictions.argmax(dim=-1)

#             for i in range(pred_tokens.shape[0]):
#                 if samples_processed >= max_samples:
#                     break

#                 pred_seq = pred_tokens[i].cpu().numpy()
#                 target_seq = target_tokens[i].cpu().numpy()

#                 valid_mask = target_seq != pad_token
#                 pred_seq = pred_seq[valid_mask]
#                 target_seq = target_seq[valid_mask]

#                 if eos_token in pred_seq:
#                     pred_seq = pred_seq[: np.where(pred_seq == eos_token)[0][0]]
#                 if eos_token in target_seq:
#                     target_seq = target_seq[: np.where(target_seq == eos_token)[0][0]]

#                 # get label index from one-hot or already-int encoded tensor
#                 label_idx = encoded_labels[i].item()

#                 all_predictions.append(pred_seq)
#                 all_targets.append(target_seq)
#                 all_labels.append(label_idx)
#                 samples_processed += 1

#     print(f"\nProcessed {samples_processed} samples total")
#     return all_predictions, all_targets, all_labels


# # ===================================
# # RUN INFERENCE
# # ===================================
# max_samples = 20
# predictions, targets, label_indices = predict_batch(model, dataloader, max_samples=max_samples, device=device)
# print(label_indices)

# # ===================================
# # CONVERT BINS → CONTINUOUS COORDINATES
# # ===================================
# def bins_to_continuous(seq, binner, bin_offset):
#     seq = np.array(seq)
#     # remove special tokens (anything below BIN_OFFSET)
#     numeric_mask = seq >= bin_offset
#     seq_numeric = seq[numeric_mask] - bin_offset
#     seq_tensor = torch.tensor(seq_numeric, dtype=torch.long)
#     seq_cont = binner.bin_to_value_torch(seq_tensor).cpu().numpy()
#     return seq_cont


# # ===================================
# # SIMULATION LOOP
# # ===================================
# for idx, (pred_seq, target_seq, label_idx) in enumerate(zip(predictions, targets, label_indices)):
#     mech_type = index_to_label[str(label_idx)]

#     # Convert discrete bins → continuous coords
#     pred_cont = bins_to_continuous(pred_seq, binner, BIN_OFFSET)
#     target_cont = bins_to_continuous(target_seq, binner, BIN_OFFSET)

#     # Drop odd lengths to form (N, 2)
#     if len(pred_cont) % 2 == 1:
#         pred_cont = pred_cont[:-1]
#     if len(target_cont) % 2 == 1:
#         target_cont = target_cont[:-1]

#     pred_joints = pred_cont.reshape(-1, 2)
#     gt_joints = target_cont.reshape(-1, 2)
#     num_joints = gt_joints.shape[0]

#     j_points_gt = [gt_joints[i].tolist() for i in range(num_joints)]
#     j_points_pred = [pred_joints[i].tolist() for i in range(min(num_joints, pred_joints.shape[0]))]
#     couplerCurveIndex = num_joints - 1  # last joint as coupler

#     # --- ORIGINAL MECHANISM SIMULATION ---
#     exampleData = {
#         "params": j_points_gt,
#         "type": mech_type,
#         "speedScale": speedscale,
#         "steps": steps,
#         "relativeTolerance": 0.1,
#     }

#     try:
#         temp = requests.post(url=API_ENDPOINT, headers=HEADERS, data=json.dumps([exampleData])).json()
#         time.sleep(0.05)
#     except ValueError:
#         continue

#     if temp[0]["poses"] is None:
#         continue

#     P = np.array(temp[0]["poses"])
#     if P.shape[0] < minsteps:
#         continue

#     original_x, original_y = P[:, couplerCurveIndex, 0], P[:, couplerCurveIndex, 1]
#     original_mean_x, original_mean_y = np.mean(original_x), np.mean(original_y)
#     original_denom = np.sqrt(np.var(original_x) + np.var(original_y))
#     original_phi = -get_pca_inclination(original_x, original_y)

#     # --- PREDICTED MECHANISM SIMULATION ---
    
#     exampleData["params"] = j_points_pred
#     try:
#         temp = requests.post(url=API_ENDPOINT, headers=HEADERS, data=json.dumps([exampleData])).json()
#         time.sleep(0.05)
#     except ValueError:
#         continue

#     if temp[0]["poses"] is None:
#         continue

#     P = np.array(temp[0]["poses"])
#     if P.shape[0] < minsteps:
#         continue

#     generated_x, generated_y = P[:, couplerCurveIndex, 0], P[:, couplerCurveIndex, 1]
#     if np.isnan(generated_x).any() or np.isinf(generated_x).any() or len(generated_x) < 30:
#         continue

#     # --- ALIGN CURVES ---
#     generated_phi = -get_pca_inclination(generated_x, generated_y)
#     rotation = generated_phi - original_phi
#     generated_x, generated_y = rotate_curve(generated_x, generated_y, rotation)

#     generated_denom = np.sqrt(np.var(generated_x) + np.var(generated_y))
#     scale_factor = original_denom / generated_denom
#     generated_x, generated_y = np.multiply(generated_x, scale_factor), np.multiply(generated_y, scale_factor)

#     generated_mean_x, generated_mean_y = np.mean(generated_x), np.mean(generated_y)
#     translation_x, translation_y = generated_mean_x - original_mean_x, generated_mean_y - original_mean_y
#     generated_x, generated_y = np.subtract(generated_x, translation_x), np.subtract(generated_y, translation_y)

#     # --- PLOT BOTH CURVES ---
#     plt.plot(original_x, original_y, "r", label="original")
#     plt.plot(generated_x, generated_y, "g", label="predicted")
#     plt.title(f"Mechanism: {mech_type}")
#     plt.axis("equal")
#     plt.legend()

#     out_dir = f"results/{idx}"
#     os.makedirs(out_dir, exist_ok=True)
#     plt.savefig(f"{out_dir}/{idx}_{mech_type}_batch_pred.jpg")
#     plt.clf()

# print(f"Finished in {time.time() - start_time:.2f} seconds")


In [None]:
# !rm results.zip
# !zip -r results.zip results