In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
from llama_latent_model import LatentLLaMA_SingleToken
from dataset import BarLinkageDataset 
import matplotlib.pyplot as plt
# Let float32 matmuls use faster, reduced-precision internal math
# ("medium" level, see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('medium')

from dataset_generation.curve_plot import get_pca_inclination, rotate_curve
import scipy.spatial.distance as sciDist
from tqdm import tqdm
import requests
import time
import matplotlib.pyplot as plt
import os
import json
import torch.nn.functional as F

# Prefer the second GPU ("cuda:1", as originally configured), but fall back to
# cuda:0 on single-GPU machines instead of failing at the first `.to(device)`,
# and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Headless simulator version: settings for the local simulation HTTP service.
HEADERS = {"Content-Type": "application/json"}
index = 0  # local server index
API_ENDPOINT = "http://localhost:4000/simulation"

# Simulation resolution. `minsteps` is the minimum number of returned poses a
# simulation must contain to be used (20/360 of `steps`).
speedscale = 1
steps = 360
minsteps = int(steps*20/360)

In [None]:
# Checkpoint and dataset locations for this evaluation run.
checkpoint_path = "./weights/CE_GAUS_MSE/LATENT_LLAMA_d1024_h32_n6_bs512_lr0.0005_best.pth"
data_dir = "/home/anurizada/Documents/processed_dataset_17"
batch_size = 1  # the evaluation loop reads only element [0] of each batch

# Deterministic (unshuffled) iteration so sample indices are stable across runs.
dataset = BarLinkageDataset(data_dir=data_dir)
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

In [None]:
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
checkpoint = torch.load(checkpoint_path, map_location=device)
model_config = checkpoint['model_config']

# Rebuild the architecture from the hyperparameters stored alongside the weights.
_model_ctor_keys = (
    "tgt_seq_len",
    "d_model",
    "h",
    "N",
    "num_labels",
    "vocab_size",
    "latent_dim",
)
model = LatentLLaMA_SingleToken(
    **{key: model_config[key] for key in _model_ctor_keys}
).to(device)

# Restore the trained weights and switch the model to eval mode.
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# ---------------------------
# Count parameters
# ---------------------------
def count_parameters(model):
    """Print a parameter-count summary for ``model`` and return the counts.

    Args:
        model: an object whose ``parameters()`` yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        ``(total_params, trainable_params, non_trainable_params)`` where
        trainable parameters are those with ``requires_grad=True`` and
        non-trainable ("frozen") parameters are the remainder.
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params

    # Header emoji restored from UTF-8/cp1252 mojibake ("ðŸ§®").
    print("\n🧮 Model Parameter Summary")
    print(f"Total parameters:     {total_params:,}  ({total_params/1e6:.2f} M)")
    print(f"Trainable parameters: {trainable_params:,}  ({trainable_params/1e6:.2f} M)")
    print(f"Frozen parameters:    {non_trainable_params:,}  ({non_trainable_params/1e6:.2f} M)")
    return total_params, trainable_params, non_trainable_params

# Run the counter
# Print the parameter summary. As the cell's last expression, the returned
# (total, trainable, frozen) tuple is also echoed as the cell output.
count_parameters(model)

In [None]:
# =========================================================
# CONFIG
# =========================================================
MAX_SAMPLES = 100         # evaluate at most this many dataset samples
SOS_TOKEN, EOS_TOKEN, PAD_TOKEN = 0, 1, 2
NUM_SPECIAL_TOKENS = 3
NUM_MECH_TYPES = 17
BIN_OFFSET = NUM_SPECIAL_TOKENS   # coordinate token ids start right after the special tokens
NUM_BINS = 201            # must match training
LATENT_DIM = 50           # must match training

label_mapping_path = "/home/anurizada/Documents/processed_dataset_17/label_mapping.json"
coupler_mapping_path = "/home/anurizada/Documents/transformer/BSIdict.json"

# JSON is UTF-8 by specification. Pass the encoding explicitly so decoding does
# not depend on the platform's locale default.
with open(label_mapping_path, "r", encoding="utf-8") as f:
    label_mapping = json.load(f)
index_to_label = label_mapping["index_to_label"]

with open(coupler_mapping_path, "r", encoding="utf-8") as f:
    coupler_mapping = json.load(f)

def coupler_index_for(mech_type: str) -> int:
    """Return coupler curve index from BSIdict.json."""
    if mech_type in coupler_mapping and "c" in coupler_mapping[mech_type]:
        cvec = coupler_mapping[mech_type]["c"]
        if isinstance(cvec, list) and 1 in cvec:
            return cvec.index(1)
    return -1


# ---------------------------------------------------------
# Helper: safe + short names for filesystem paths
# ---------------------------------------------------------
def safe_name(name: str, max_len: int = 30) -> str:
    """Make ``name`` usable as a path component.

    Every non-alphanumeric character becomes ``_`` and the result is cut to at
    most ``max_len`` characters. An empty result is replaced by ``"unk"``.
    """
    cleaned = "".join(ch if ch.isalnum() else "_" for ch in name)
    return cleaned[:max_len] or "unk"


def temp_to_str(t: float) -> str:
    """Encode a temperature as a filename-safe token.

    The value is rounded to two decimals and trailing zeros and dot are removed.
    Then ``.`` becomes ``p`` and ``-`` becomes ``m``. Examples: 0.5 -> "0p5",
    2.0 -> "2", -1.5 -> "m1p5".
    """
    text = format(t, ".2f")
    text = text.rstrip("0").rstrip(".")
    encoded = text.replace(".", "p").replace("-", "m")
    return encoded if encoded else "0"


# =========================================================
# CoordinateBinner (same as training)
# =========================================================
class CoordinateBinner:
    """Uniform quantizer over ``[-kappa, kappa]`` that maps bin indices to bin centers."""

    def __init__(self, kappa=1.0, num_bins=200):
        self.kappa = kappa
        self.num_bins = num_bins
        # num_bins + 1 evenly spaced edges delimit num_bins equal-width bins.
        edges = np.linspace(-kappa, kappa, num_bins + 1)
        self.bin_edges = edges
        # Center of each bin = midpoint of its two edges.
        self.bin_centers = (edges[1:] + edges[:-1]) / 2

    def bin_to_value_torch(self, bin_index_tensor):
        """Map integer bin indices to float32 bin centers.

        Indices outside ``[0, num_bins - 1]`` are clamped into range. The result
        lives on the same device as ``bin_index_tensor``.
        """
        lookup = torch.tensor(
            self.bin_centers, dtype=torch.float32, device=bin_index_tensor.device
        )
        in_range = bin_index_tensor.clamp(0, self.num_bins - 1)
        return lookup[in_range]


# Module-level binner used to decode coordinate tokens: NUM_BINS - 1 bins over
# [-1, 1] (200 bins when NUM_BINS = 201).
# NOTE(review): confirm the training-time binner also used NUM_BINS - 1 bins.
binner = CoordinateBinner(kappa=1.0, num_bins=NUM_BINS - 1)


# =========================================================
# Causal mask builder
# =========================================================
def build_causal_mask(seq_len: int, device: torch.device):
    """Return a boolean lower-triangular mask of shape ``(1, 1, seq_len, seq_len)``.

    Entry ``[0, 0, r, c]`` is True exactly when ``c <= r``.
    """
    square = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
    return square.tril()[None, None]


# =========================================================
# AUTOREGRESSIVE PREDICTION (Latent LLaMA)
# =========================================================
def _select_next_token(
    next_logits: torch.Tensor,
    temperature: float,
    top_k: int | None,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Choose the next token id(s) from last-position logits; returns shape (batch, 1)."""
    # Greedy decoding is checked first so temperature <= 0 means argmax even
    # when top_k is also given (previously the top-k branch shadowed it).
    if temperature <= 0:
        return torch.argmax(next_logits, dim=-1, keepdim=True)

    probs = F.softmax(next_logits / max(temperature, 1e-6), dim=-1)

    if top_k is not None and top_k > 0:
        k = min(int(top_k), probs.size(-1))
        topk_probs, topk_idx = torch.topk(probs, k=k, dim=-1)
        # multinomial normalizes its weights, so this samples from the
        # renormalized top-k distribution.
        choice = torch.multinomial(topk_probs, num_samples=1, generator=generator)
        return topk_idx.gather(-1, choice)

    return torch.multinomial(probs, num_samples=1, generator=generator)


def predict_autoregressive_latent(
    model,
    latent,
    mech_idx: int,
    max_seq_len: int,
    device,
    temperature: float = 1.0,
    top_k: int | None = None,
    eos_token: int = EOS_TOKEN,
    sos_token: int = SOS_TOKEN,
    generator: torch.Generator | None = None,
):
    """Autoregressively decode one token sequence conditioned on a latent vector.

    Args:
        model: called as ``model(decoder_input, causal_mask, latent, mech_labels)``;
            its output is indexed ``[:, -1, :]`` to get next-token logits.
        latent: latent tensor; a 1-D tensor is promoted to a batch of one.
        mech_idx: mechanism-type label passed to the model as a length-1 long tensor.
        max_seq_len: maximum number of tokens generated after the SOS token.
        device: device that inputs are moved to.
        temperature: ``<= 0`` selects greedy (argmax) decoding. Otherwise logits
            are divided by ``max(temperature, 1e-6)`` before sampling.
        top_k: when positive, sample only among the ``top_k`` most likely tokens.
        eos_token: generation stops right after this token is emitted.
        sos_token: first token of the decoder input.
        generator: optional ``torch.Generator`` for reproducible sampling. It must
            live on the same device as the logits.

    Returns:
        1-D numpy array of token ids starting with ``sos_token``. It ends with
        ``eos_token`` if that token was produced within ``max_seq_len`` steps.

    Note:
        Only a single sequence is supported: ``.item()`` is called on the
        sampled token each step.
    """
    model.eval()

    if latent.dim() == 1:
        latent = latent.unsqueeze(0)
    latent = latent.to(device)
    mech_labels = torch.tensor([mech_idx], device=device, dtype=torch.long)

    with torch.no_grad():
        decoder_input = torch.tensor([[sos_token]], device=device, dtype=torch.long)

        for _ in range(max_seq_len):
            causal_mask = build_causal_mask(decoder_input.size(1), device)
            logits = model(decoder_input, causal_mask, latent, mech_labels)

            next_token = _select_next_token(logits[:, -1, :], temperature, top_k, generator)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            if int(next_token.item()) == eos_token:
                break

    return decoder_input.squeeze(0).cpu().numpy()


# =========================================================
# MAIN LOOP
# =========================================================

tgt_seq_len = model_config["tgt_seq_len"]

# Sweep grid: every mechanism type is decoded under each (temperature, top-k) pair.
TEMPERATURES = [0.0, 0.5, 1.0, 1.5, 2.0]
TOP_K_VALUES = [1, 5, 10, 20]
SIM_TIMEOUT_S = 30  # seconds to wait for the simulator before giving up on a request
RESULTS_DIR = "results_coupler_latent"


def _tokens_to_points(tokens):
    """Decode coordinate tokens into an (n, 2) array of joint positions, or None.

    Special tokens (ids below BIN_OFFSET) are dropped and the rest are mapped to
    bin centers. A trailing unpaired value is discarded. Returns None when fewer
    than four coordinate tokens remain.
    """
    tokens = np.asarray(tokens)
    coord_tokens = tokens[tokens >= BIN_OFFSET]
    if coord_tokens.size < 4:
        return None
    bins = torch.as_tensor(coord_tokens, dtype=torch.long, device=device) - BIN_OFFSET
    coords = binner.bin_to_value_torch(bins).cpu().numpy()
    if coords.size % 2 == 1:
        coords = coords[:-1]
    return coords.reshape(-1, 2)


def _simulate_poses(params, mech_type):
    """POST one mechanism to the simulator; return its float poses array or None.

    The array is indexed as ``poses[:, joint, 0/1]`` (x/y) below, presumably
    (step, joint, xy). Unreachable servers, malformed replies, non-3-D arrays
    and runs shorter than ``minsteps`` all yield None.
    """
    payload = {
        "params": params,
        "type": mech_type,
        "speedScale": speedscale,
        "steps": steps,
        "relativeTolerance": 0.1,
    }
    try:
        response = requests.post(
            url=API_ENDPOINT,
            headers=HEADERS,
            data=json.dumps([payload]),
            timeout=SIM_TIMEOUT_S,
        )
        result = response.json()
        if not (isinstance(result, list) and result and "poses" in result[0]):
            return None
        # float dtype so the in-place-free scaling below never hits integer casting.
        poses = np.asarray(result[0]["poses"], dtype=float)
    except (requests.RequestException, ValueError, TypeError, KeyError, IndexError):
        # Best-effort: a failed simulation just skips this case. Unlike a bare
        # `except:`, this still lets KeyboardInterrupt stop the sweep.
        return None
    if poses.ndim != 3 or poses.shape[0] < minsteps:
        return None
    return poses


def _coupler_curve(poses, coupler_idx):
    """Return the (x, y) trajectory of joint ``coupler_idx``, or None if unusable."""
    if coupler_idx < 0 or coupler_idx >= poses.shape[1] or poses.shape[2] < 2:
        return None
    x = poses[:, coupler_idx, 0]
    y = poses[:, coupler_idx, 1]
    # Check both coordinates (previously only x was checked for NaN/inf).
    if not (np.isfinite(x).all() and np.isfinite(y).all()):
        return None
    return x, y


def _align_to_reference(x, y, ref_phi, ref_spread, ref_center):
    """Rotate, rescale and shift a curve onto the reference curve's frame.

    The rotation is the difference of the negated PCA inclinations. The scale
    matches sqrt(var(x) + var(y)) to ``ref_spread``, and the translation matches
    the centroid to ``ref_center``.
    """
    gen_phi = -get_pca_inclination(x, y)
    x, y = rotate_curve(x, y, gen_phi - ref_phi)
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    scale = ref_spread / (np.sqrt(np.var(x) + np.var(y)) + 1e-8)
    x = x * scale
    y = y * scale
    x = x - (np.mean(x) - ref_center[0])
    y = y - (np.mean(y) - ref_center[1])
    return x, y


def _save_pair_plot(path, gt_curve, gen_curve, gt_points, pred_points, gt_label, pred_label, title):
    """Overlay the GT and aligned predicted coupler curves plus their joints; save to ``path``."""
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(gt_curve[0], gt_curve[1], "r-", label=gt_label)
    ax.plot(gen_curve[0], gen_curve[1], "g--", label=pred_label)
    ax.scatter(gt_points[:, 0], gt_points[:, 1], color="red", s=40)
    ax.scatter(pred_points[:, 0], pred_points[:, 1], color="green", s=40)
    ax.axis("equal")
    ax.legend()
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def _save_joint_scatter(path, gt_points, gt_mech_name, predicted, sample_idx):
    """Scatter all predicted joints (colored by joint index) over the GT joints; save to ``path``."""
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(
        gt_points[:, 0],
        gt_points[:, 1],
        c="red",
        s=80,
        edgecolor="black",
        label=f"GT ({gt_mech_name})",
        zorder=6,
    )
    for j, (x, y) in enumerate(gt_points):
        ax.text(x + 0.005, y + 0.005, f"{j}", color="red", fontsize=9)

    max_joints = max(pts.shape[0] for pts in predicted.values())
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    cmap = plt.get_cmap("tab10", max_joints)

    for pts in predicted.values():
        colors = [cmap(j) for j in range(pts.shape[0])]
        # One scatter call per prediction instead of one per joint.
        ax.scatter(pts[:, 0], pts[:, 1], color=colors, s=30, alpha=0.8)
        for j, (x, y) in enumerate(pts):
            ax.text(x + 0.002, y + 0.002, str(j), fontsize=7, color=colors[j])

    ax.set_title(f"Predicted Joint Positions — Sample {sample_idx} (GT={gt_mech_name})")
    ax.set_xlabel("X coordinate")
    ax.set_ylabel("Y coordinate")
    ax.axis("equal")

    handles = [
        plt.Line2D([0], [0], marker="o", color="w", markerfacecolor=cmap(j), label=f"Joint {j}")
        for j in range(max_joints)
    ]
    ax.legend(handles=handles, fontsize=7, loc="upper right", ncol=2)

    fig.tight_layout()
    fig.savefig(path, dpi=200)
    plt.close(fig)


print("Starting conditional coupler curve generation (latent-based)...")
os.makedirs(RESULTS_DIR, exist_ok=True)

for i, batch in enumerate(tqdm(dataloader, total=min(MAX_SAMPLES, len(dataset)), desc="Simulating")):
    if i >= MAX_SAMPLES:
        break

    # --- Latent input (first element of the batch) ---
    latent = batch["vae_mu"].to(device).squeeze(-1)[0]

    # --- Ground truth ---
    gt_tokens = batch["labels_discrete"][0].numpy()
    gt_mech_idx = int(batch["encoded_labels"][0].item())
    gt_mech_name = index_to_label.get(str(gt_mech_idx), "UNKNOWN")

    sample_dir = os.path.join(RESULTS_DIR, f"sample_{i:03d}_{safe_name(gt_mech_name)}")
    os.makedirs(sample_dir, exist_ok=True)

    if "images" in batch:
        img_np = batch["images"][0].detach().cpu().squeeze().numpy()
        plt.imsave(os.path.join(sample_dir, "input_image.png"), img_np, cmap="gray")

    gt_points = _tokens_to_points(gt_tokens)
    if gt_points is None:
        continue

    # Skip the simulator call entirely when there is no coupler curve to extract.
    coup_idx_gt = coupler_index_for(gt_mech_name)
    if coup_idx_gt < 0:
        continue
    gt_poses = _simulate_poses(gt_points.tolist(), gt_mech_name)
    if gt_poses is None:
        continue
    gt_curve = _coupler_curve(gt_poses, coup_idx_gt)
    if gt_curve is None:
        continue

    original_x, original_y = gt_curve
    orig_phi = -get_pca_inclination(original_x, original_y)
    orig_denom = np.sqrt(np.var(original_x) + np.var(original_y)) + 1e-8
    orig_center = (np.mean(original_x), np.mean(original_y))

    all_predicted_points: dict[str, np.ndarray] = {}

    # ============================================================
    # Temperature × Top-K sweep over every mechanism type
    # ============================================================
    for mech_idx in range(NUM_MECH_TYPES):
        mech_name = index_to_label.get(str(mech_idx), f"mech_{mech_idx}")
        mech_name_safe = safe_name(mech_name)
        # Loop-invariant: depends only on the mechanism, not on the sampling settings.
        coup_idx_pred = coupler_index_for(mech_name)

        for temperature in TEMPERATURES:
            temp_str = temp_to_str(temperature)

            for top_k in TOP_K_VALUES:
                pred_tokens = predict_autoregressive_latent(
                    model=model,
                    latent=latent,
                    mech_idx=mech_idx,
                    max_seq_len=tgt_seq_len,
                    device=device,
                    temperature=temperature,
                    top_k=top_k,
                )

                pred_points = _tokens_to_points(pred_tokens)
                if pred_points is None:
                    continue

                key = f"{mech_name_safe}_t{temp_str}_k{top_k}"
                all_predicted_points[key] = pred_points

                # Points are still recorded for the joint scatter above; only the
                # (useless) simulation is skipped when there is no coupler index.
                if coup_idx_pred < 0:
                    continue
                pred_poses = _simulate_poses(pred_points.tolist(), mech_name)
                if pred_poses is None:
                    continue
                pred_curve = _coupler_curve(pred_poses, coup_idx_pred)
                if pred_curve is None:
                    continue

                generated_x, generated_y = _align_to_reference(
                    pred_curve[0], pred_curve[1], orig_phi, orig_denom, orig_center
                )

                _save_pair_plot(
                    os.path.join(sample_dir, f"mech_{mech_name_safe}_t{temp_str}_k{top_k}.png"),
                    (original_x, original_y),
                    (generated_x, generated_y),
                    gt_points,
                    pred_points,
                    gt_label=f"GT Coupler ({gt_mech_name})",
                    pred_label=f"Pred ({mech_name}) t={temperature} k={top_k}",
                    title=f"Sample {i} | GT={gt_mech_name} | Pred={mech_name} | t={temperature} k={top_k}",
                )

    # ============================================================
    # Combined joint prediction scatter plot
    # ============================================================
    if not all_predicted_points:
        continue

    joint_scatter_path = os.path.join(sample_dir, "all_predicted_joints_colored.png")
    _save_joint_scatter(joint_scatter_path, gt_points, gt_mech_name, all_predicted_points, i)
    print(f"✅ Saved: {joint_scatter_path}")

print("✅ Finished all mechanism variations (latent-based).")
