In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
from model import SingleImageTransformer
from dataset import BarLinkageDataset 
import matplotlib.pyplot as plt
torch.set_float32_matmul_precision('medium')

from dataset_generation.curve_plot import get_pca_inclination, rotate_curve
import scipy.spatial.distance as sciDist
from tqdm import tqdm
import requests
import time
import matplotlib.pyplot as plt
import os
import json
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Simulator endpoint and request settings (headless simulator version).
index = 0  # local server index
API_ENDPOINT = "http://localhost:4000/simulation"
HEADERS = {"Content-Type": "application/json"}

speedscale = 1  # sent to the simulator as "speedScale"
steps = 360     # sent to the simulator as "steps"
# Minimum number of pose frames a simulation must return to be kept: 20/360 of `steps`.
minsteps = int(steps * 20 / 360)

In [None]:
# Checkpoint to evaluate and the processed dataset it is evaluated on.
checkpoint_path = "weights/transformer_weights_CE/d1024_h32_n6_bs512_lr0.0001_best.pth"
data_dir = "/home/anurizada/Documents/processed_dataset"

# One sample per batch: the evaluation loop reads element [0] of every batch.
batch_size = 1
dataset = BarLinkageDataset(data_dir=data_dir)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,  # keep dataset order so sample indices are stable across runs
)

In [None]:
# Restore the saved hyperparameters and weights from the checkpoint.
# NOTE(review): torch.load may unpickle arbitrary objects (depending on the torch version /
# weights_only default) — only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location=device)
model_config = checkpoint['model_config']

# Rebuild the architecture from the stored hyperparameters, then load the weights into it.
model_kwargs = {
    key: model_config[key]
    for key in ("tgt_seq_len", "d_model", "h", "N", "num_labels", "vocab_size")
}
model = SingleImageTransformer(**model_kwargs).to(device)

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

In [None]:
# ===============================
# CONFIG
# ===============================
MAX_SAMPLES = 5  # number of dataset samples to evaluate
# One prediction is decoded (and simulated) per temperature for every sample.
TEMPERATURES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 1.0, 1.3, 1.5, 1.7, 2.0]

# Seed torch's RNGs so the top-k sampling in the prediction loop is reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

# Token id layout: special tokens first, then mechanism-type tokens, then coordinate bins.
SOS_TOKEN, EOS_TOKEN, PAD_TOKEN = 0, 1, 2
NUM_SPECIAL_TOKENS = 3
NUM_MECH_TYPES = 17
BIN_OFFSET = NUM_SPECIAL_TOKENS + NUM_MECH_TYPES  # = 20; first coordinate-bin token id

# API_ENDPOINT, HEADERS, speedscale, steps and minsteps are defined in the simulator config cell.

label_mapping_path = "/home/anurizada/Documents/processed_dataset/label_mapping.json"
coupler_mapping_path = "/home/anurizada/Documents/transformer/BSIdict.json"


def _read_json(path):
    """Parse and return the JSON document stored at `path`."""
    with open(path, "r") as handle:
        return json.load(handle)


# Mechanism-type index (as a string key) -> mechanism name.
label_mapping = _read_json(label_mapping_path)
index_to_label = label_mapping["index_to_label"]

# Mechanism name -> per-mechanism entry from BSIdict.json (read by coupler_index_for).
coupler_mapping = _read_json(coupler_mapping_path)

def coupler_index_for(mech_type: str, mapping=None) -> int:
    """Return the coupler joint index for `mech_type` from BSIdict.json, or -1 if unavailable.

    The entry `mapping[mech_type]["c"]` is expected to be a list that flags the coupler
    joint with a 1; the position of the first 1 is returned.

    Args:
        mech_type: Mechanism name as used as a key in BSIdict.json.
        mapping: Mechanism dictionary to search. Defaults to the module-level
            `coupler_mapping` (looked up at call time), so existing calls are unchanged.

    Returns:
        Index of the first 1 in the `"c"` list, or -1 when the mechanism is missing, its
        entry is not a dict, it has no list-valued `"c"`, or that list contains no 1.
    """
    if mapping is None:
        mapping = coupler_mapping
    entry = mapping.get(mech_type)
    if not isinstance(entry, dict):
        return -1
    cvec = entry.get("c")
    if isinstance(cvec, list) and 1 in cvec:
        return cvec.index(1)
    return -1

# ===============================
# BINNER
# ===============================
class CoordinateBinner:
    """Uniform quantiser of the interval [-kappa, kappa] into `num_bins` equal-width bins.

    Attributes:
        kappa: Half-width of the covered interval.
        num_bins: Number of bins.
        bin_edges: `num_bins + 1` evenly spaced edges from -kappa to kappa (numpy array).
        bin_centers: Midpoint of every bin (numpy array of length `num_bins`).
    """

    def __init__(self, kappa=1.0, num_bins=200):
        self.kappa = kappa
        self.num_bins = num_bins
        edges = np.linspace(-kappa, kappa, num_bins + 1)
        self.bin_edges = edges
        self.bin_centers = (edges[:-1] + edges[1:]) / 2

    def bin_to_value_torch(self, bin_index_tensor):
        """Map integer bin indices to their bin-centre values as a float32 tensor.

        Out-of-range indices are first clamped into [0, num_bins - 1]. The result is created
        on the same device as `bin_index_tensor` and is indexed by it element-wise.
        """
        safe_idx = torch.clamp(bin_index_tensor, min=0, max=self.num_bins - 1)
        centers = torch.tensor(self.bin_centers, dtype=torch.float32, device=safe_idx.device)
        return centers[safe_idx]


# Decoder for coordinate tokens: (token id - BIN_OFFSET) is looked up with this binner.
binner = CoordinateBinner(kappa=1.0, num_bins=201)

# ===============================
# AUTOREGRESSIVE PREDICTION
# ===============================
def predict_autoregressive(model, image, max_seq_len, device, temperature=1.0, top_k=None, generator=None):
    """Decode one token sequence for `image`, starting from SOS and stopping at EOS.

    Args:
        model: Decoder called as `model(decoder_input, causal_mask, image, None)`; its
            output is indexed as `logits[:, -1, :]` to obtain next-token logits.
        image: Conditioning input passed through to `model` unchanged.
        max_seq_len: Maximum number of tokens generated after SOS.
        device: Device on which the decoder input and mask are created.
        temperature: Softmax temperature; must be > 0. With `top_k=None` it does not
            change the chosen token (argmax is temperature-invariant for T > 0).
        top_k: None for greedy decoding; otherwise sample from the `top_k` most likely
            tokens (must be >= 1).
        generator: Optional `torch.Generator` (on `device`) used by the sampler, for
            reproducible draws independent of the global RNG.

    Returns:
        1-D numpy array of token ids starting with SOS and, if it was produced, ending
        with EOS. Expects a batch of one (`.item()` is taken on the next token).

    Raises:
        ValueError: If `temperature` is not positive or `top_k` is below 1.
    """
    # Dividing by a zero/negative/NaN temperature would yield NaN probabilities.
    if not temperature > 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    # k = 0 would leave torch.multinomial nothing to sample from.
    if top_k is not None and int(top_k) < 1:
        raise ValueError(f"top_k must be >= 1 or None, got {top_k!r}")

    model.eval()
    with torch.no_grad():
        decoder_input = torch.tensor([[SOS_TOKEN]], device=device, dtype=torch.long)
        for _ in range(max_seq_len):
            T = decoder_input.size(1)
            # Lower-triangular mask (presumably True = "may attend"; confirm against the model).
            causal_mask = torch.tril(torch.ones(T, T, device=device)).bool()

            logits = model(decoder_input, causal_mask, image, None)
            next_logits = logits[:, -1, :] / float(temperature)
            probs = F.softmax(next_logits, dim=-1)

            if top_k is None:
                next_token = torch.argmax(probs, dim=-1)
            else:
                k = min(int(top_k), probs.size(-1))
                topk_probs, topk_idx = torch.topk(probs, k=k, dim=-1)
                # multinomial accepts unnormalised non-negative weights, so no renormalisation.
                sampled = torch.multinomial(topk_probs, num_samples=1, generator=generator)
                next_token = topk_idx.gather(-1, sampled).squeeze(1)

            token = int(next_token.item())
            decoder_input = torch.cat([decoder_input, next_token.unsqueeze(1)], dim=1)
            if token == EOS_TOKEN:
                break
    return decoder_input.squeeze(0).cpu().numpy()

# ===============================
# MAIN LOOP
# ===============================
# For each of the first MAX_SAMPLES samples: simulate the ground-truth mechanism, then decode
# one prediction per temperature, simulate it, align its coupler curve onto the GT curve and
# save a comparison plot to results_coupler/sample_XXX_<mech>/temp_<T>.png.

# Per-request timeout (seconds) so a hung simulator cannot block the loop indefinitely.
SIM_TIMEOUT_S = 120


def _tokens_to_points(coord_tokens):
    """Decode coordinate tokens (ids >= BIN_OFFSET) into an (N, 2) array of joint positions.

    An odd trailing value (an x without its y) is dropped before pairing.
    """
    bins = torch.tensor(coord_tokens) - BIN_OFFSET
    values = binner.bin_to_value_torch(bins).cpu().numpy()
    if values.size % 2 == 1:
        values = values[:-1]
    return values.reshape(-1, 2)


def _simulate_poses(points, mech_type, fail_msg):
    """POST one mechanism to the simulator and return its poses as a float array, or None.

    Any exception (connection error, timeout, malformed or ragged JSON) is printed as
    "<fail_msg>: <error>" and mapped to None; a response without "poses" is also None.
    """
    payload = {
        "params": points.tolist(),
        "type": mech_type,
        "speedScale": speedscale,
        "steps": steps,
        "relativeTolerance": 0.1,
    }
    try:
        response = requests.post(
            url=API_ENDPOINT, headers=HEADERS, data=json.dumps([payload]), timeout=SIM_TIMEOUT_S
        )
        result = response.json()
        if isinstance(result, list) and result and "poses" in result[0]:
            # dtype=float turns JSON nulls into NaN so _coupler_curve can reject them.
            return np.array(result[0]["poses"], dtype=float)
        return None
    except Exception as e:
        print(f"{fail_msg}: {e}")
        return None


def _coupler_curve(poses, coupler_idx):
    """Return the (x, y) trace of joint `coupler_idx` from `poses`, or None if unusable.

    `poses` is indexed as poses[frame, joint, xy]. The result is rejected when there are fewer
    than `minsteps` frames, the array is not 3-D, the joint index is out of range, or any
    coordinate is non-finite.
    """
    if poses is None or poses.ndim != 3 or poses.shape[0] < minsteps:
        return None
    if not 0 <= coupler_idx < poses.shape[1]:
        return None
    x, y = poses[:, coupler_idx, 0], poses[:, coupler_idx, 1]
    if not (np.isfinite(x).all() and np.isfinite(y).all()):
        return None
    return x, y


def _align_to_reference(gen_x, gen_y, ref_phi, ref_denom, ref_x_mean, ref_y_mean):
    """Rotate, scale and translate a generated curve so it overlays the reference curve.

    The curve is rotated by the PCA-inclination difference, then scaled to the reference's
    sqrt(var(x) + var(y)), then shifted so both means coincide.
    """
    gen_phi = -get_pca_inclination(gen_x, gen_y)
    gen_x, gen_y = rotate_curve(gen_x, gen_y, gen_phi - ref_phi)
    gen_denom = np.sqrt(np.var(gen_x) + np.var(gen_y)) + 1e-8
    scale = ref_denom / gen_denom
    gen_x = gen_x * scale
    gen_y = gen_y * scale
    gen_x = gen_x - (np.mean(gen_x) - ref_x_mean)
    gen_y = gen_y - (np.mean(gen_y) - ref_y_mean)
    return gen_x, gen_y


print("Starting multi-temperature coupler curve generation...")
os.makedirs("results_coupler", exist_ok=True)

for i, batch in enumerate(tqdm(dataloader, total=MAX_SAMPLES, desc="Simulating")):
    if i >= MAX_SAMPLES:
        break

    image = batch["images"].to(device)
    gt_tokens = batch["labels_discrete"][0].numpy()

    # --- Ground Truth ---
    # NOTE(review): GT mech type is read from index 0 (no leading SOS) while predictions are
    # read from index 1 (after SOS) — confirm against the dataset's label format.
    gt_mech_idx = gt_tokens[0] - NUM_SPECIAL_TOKENS
    gt_mech_name = index_to_label.get(str(gt_mech_idx), "UNKNOWN")
    gt_coord_tokens = [t for t in gt_tokens[1:] if t >= BIN_OFFSET]
    if len(gt_coord_tokens) < 4:
        continue
    gt_points = _tokens_to_points(gt_coord_tokens)

    # Create subfolder for this sample
    sample_dir = f"results_coupler/sample_{i:03d}_{gt_mech_name}"
    os.makedirs(sample_dir, exist_ok=True)

    # --- Simulate GT coupler ---
    coup_idx_gt = coupler_index_for(gt_mech_name)
    if coup_idx_gt < 0:
        continue  # no coupler joint known for this mechanism: skip the simulator call
    P = _simulate_poses(gt_points, gt_mech_name, "GT sim failed")
    gt_curve = _coupler_curve(P, coup_idx_gt)
    if gt_curve is None:
        continue
    original_x, original_y = gt_curve
    orig_phi = -get_pca_inclination(original_x, original_y)
    orig_denom = np.sqrt(np.var(original_x) + np.var(original_y)) + 1e-8
    ox_mean, oy_mean = np.mean(original_x), np.mean(original_y)

    print(f"\n=== Sample {i} | GT={gt_mech_name} ===")
    print("GT joint positions:")
    for j, (x, y) in enumerate(gt_points):
        print(f"  J{j}: ({x:.3f}, {y:.3f})")

    # --- Multiple Predictions (different temperatures) ---
    for temp_val in TEMPERATURES:
        pred_tokens = predict_autoregressive(model, image, model_config["tgt_seq_len"], device, temperature=temp_val, top_k=32)
        mech_idx = pred_tokens[1] - NUM_SPECIAL_TOKENS if len(pred_tokens) > 1 else -1
        mech_name = index_to_label.get(str(mech_idx), "UNKNOWN")

        coord_tokens = [t for t in pred_tokens if t >= BIN_OFFSET]
        if len(coord_tokens) < 4:
            continue
        pred_points = _tokens_to_points(coord_tokens)

        print(f"\n[Temp={temp_val:.1f}] Pred mech: {mech_name}")
        print("Pred joint positions:")
        for j, (x, y) in enumerate(pred_points):
            print(f"  J{j}: ({x:.3f}, {y:.3f})")

        # --- Simulate predicted coupler ---
        coup_idx_pred = coupler_index_for(mech_name)
        if coup_idx_pred < 0:
            continue
        Pp = _simulate_poses(pred_points, mech_name, f"Pred sim failed (temp={temp_val})")
        # Validate the *predicted* poses Pp (frame count, joint index, finite x and y).
        pred_curve = _coupler_curve(Pp, coup_idx_pred)
        if pred_curve is None:
            continue

        # --- Align predicted to GT ---
        generated_x, generated_y = _align_to_reference(
            pred_curve[0], pred_curve[1], orig_phi, orig_denom, ox_mean, oy_mean
        )

        # --- Plot ---
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.plot(original_x, original_y, "r-", label=f"GT Coupler ({gt_mech_name})")
        ax.plot(generated_x, generated_y, "g--", label=f"Pred Coupler ({mech_name})")

        # Plot joint points
        ax.scatter(gt_points[:, 0], gt_points[:, 1], color="red", s=40, zorder=5)
        for j, (x, y) in enumerate(gt_points):
            ax.text(x, y, f"J{j}", color="red", fontsize=9, weight="bold")

        ax.scatter(pred_points[:, 0], pred_points[:, 1], color="green", s=40, zorder=5)
        for j, (x, y) in enumerate(pred_points):
            ax.text(x, y, f"J{j}", color="green", fontsize=9, weight="bold")

        ax.axis("equal")
        ax.legend()
        ax.set_title(f"Sample {i} | GT={gt_mech_name} | Pred={mech_name} | Temp={temp_val}")
        fig.tight_layout()
        save_path = os.path.join(sample_dir, f"temp_{temp_val:.1f}.png")
        fig.savefig(save_path)
        plt.close(fig)  # free the figure; many are created in this loop
        print(f"✅ Saved: {save_path}")

print("✅ Finished all temperature variations.")
