# CLIP4CAD-GFA v4.2 Evaluation

## Metrics
1. **Self-grounding quality**: cosine(z_guided, z_self); target > 0.85
2. **Query alignment**: cosine(T_feat, Q_self); target > 0.7
3. **Retrieval**: R@1, R@5, R@10 for Text→BRep and Text→PC
4. **Self-path gap**: (guided R@1) − (self R@1); target < 10 percentage points

In [None]:
# Cell 1: Imports and Setup
import sys
sys.path.insert(0, '..')

import os
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from pathlib import Path

from clip4cad.models import CLIP4CAD_GFA_v4_2, GFAv4_2Config
from clip4cad.losses.gfa_v4_2_losses import compute_self_grounding_quality, compute_query_alignment

# Choose the compute device: CUDA when PyTorch reports it, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    # Only query the GPU name when a CUDA device was actually selected
    print(f"GPU: {torch.cuda.get_device_name()}")

In [None]:
# Cell 2: Data Paths

# Input locations. The defaults are the paths used on the original development
# machine; each one can be overridden with an environment variable so the
# notebook can run elsewhere without editing this cell.
DATA_ROOT = Path(os.environ.get("CLIP4CAD_DATA_ROOT", "d:/Defect_Det/MMCAD/data"))
PC_FILE = Path(os.environ.get("CLIP4CAD_PC_FILE", "c:/Users/User/Desktop/pc_embeddings_full.h5"))
BREP_FILE = Path(os.environ.get("CLIP4CAD_BREP_FILE", "c:/Users/User/Desktop/brep_features.h5"))
TEXT_FILE = Path(os.environ.get("CLIP4CAD_TEXT_FILE", "c:/Users/User/Desktop/text_embeddings.h5"))
MODEL_PATH = Path(os.environ.get("CLIP4CAD_MODEL_PATH", "../outputs/gfa_v4_2/clip4cad_gfa_v4_2_final.pt"))

print(f"Data root: {DATA_ROOT}")
print(f"Model: {MODEL_PATH} (exists: {MODEL_PATH.exists()})")

# Report missing inputs here so a wrong path is visible immediately instead of
# surfacing as a failure in a later cell.
for _label, _path in (
    ("DATA_ROOT", DATA_ROOT),
    ("PC_FILE", PC_FILE),
    ("BREP_FILE", BREP_FILE),
    ("TEXT_FILE", TEXT_FILE),
):
    if not _path.exists():
        print(f"WARNING: {_label} not found: {_path}")

In [None]:
# Cell 3: Load Model

# Read the checkpoint onto the selected device and rebuild the config object
# from the dict stored under its 'config' key.
checkpoint = torch.load(MODEL_PATH, map_location=device)
config = GFAv4_2Config(**checkpoint['config'])

# Build the network, move it to the device, restore the weights, and switch
# to evaluation mode.
model = CLIP4CAD_GFA_v4_2(config)
model = model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Conditioning dropout of 1.0: evaluate in fully independent mode (no hints)
model.set_cond_dropout(1.0)

print(f"Loaded model with {model.count_parameters():,} parameters")
print(f"Conditioning dropout: 1.0 (fully independent)")

In [None]:
# Cell 4: Load Validation Data

from clip4cad.data.gfa_dataset import GFAMappedDataset, gfa_collate_fn

print("Loading validation data...")

# Dataset options gathered in one place; paths are passed as plain strings.
dataset_kwargs = dict(
    data_root=str(DATA_ROOT),
    split="val",
    pc_file=str(PC_FILE),
    text_file=str(TEXT_FILE),
    brep_file=str(BREP_FILE),
    num_rotations=1,
    load_to_memory=False,
    use_live_text=False,
)
val_dataset = GFAMappedDataset(**dataset_kwargs)

# Fixed (unshuffled) order, loaded in the main process (num_workers=0).
loader_kwargs = dict(
    batch_size=256,
    shuffle=False,
    num_workers=0,
    collate_fn=gfa_collate_fn,
    pin_memory=True,
)
val_loader = DataLoader(val_dataset, **loader_kwargs)

print(f"Validation samples: {len(val_dataset)}")

In [None]:
# Cell 5: Collect Embeddings

# Per-batch accumulators; each list holds one CPU tensor per batch.
all_z_brep_guided = []
all_z_pc_guided = []
all_z_brep_self = []
all_z_pc_self = []
all_z_text = []
all_T_feat = []
all_Q_brep_self = []
all_Q_pc_self = []
all_confidence = []
all_uids = []

# Model output key -> list that accumulates that output across batches.
# Note that the 'z_brep' / 'z_pc' outputs are stored as the "guided" embeddings.
_collectors = {
    'z_brep': all_z_brep_guided,
    'z_pc': all_z_pc_guided,
    'z_brep_self': all_z_brep_self,
    'z_pc_self': all_z_pc_self,
    'z_text': all_z_text,
    'T_feat': all_T_feat,
    'Q_brep_self': all_Q_brep_self,
    'Q_pc_self': all_Q_pc_self,
    'confidence': all_confidence,
}

print("Extracting embeddings...")
with torch.no_grad():
    for batch in tqdm(val_loader):
        outputs = model(batch)

        # Move every collected output to CPU so results do not pile up on the
        # model's device.
        for _key, _bucket in _collectors.items():
            _bucket.append(outputs[_key].cpu())

        batch_uids = batch.get('sample_id')
        if batch_uids is None:
            # No ids in the batch: synthesize them, offset by the number of
            # ids already collected so they stay unique across batches
            # (restarting at sample_0 per batch would create duplicates).
            offset = len(all_uids)
            batch_uids = [f"sample_{offset + i}" for i in range(len(outputs['z_brep']))]
        all_uids.extend(batch_uids)

if not all_z_text:
    # torch.cat on an empty list fails with an unhelpful message; be explicit.
    raise RuntimeError("val_loader yielded no batches; there are no embeddings to evaluate")

# Concatenate the per-batch tensors along the sample dimension.
z_brep_guided = torch.cat(all_z_brep_guided, dim=0)
z_pc_guided = torch.cat(all_z_pc_guided, dim=0)
z_brep_self = torch.cat(all_z_brep_self, dim=0)
z_pc_self = torch.cat(all_z_pc_self, dim=0)
z_text = torch.cat(all_z_text, dim=0)
T_feat = torch.cat(all_T_feat, dim=0)
Q_brep_self = torch.cat(all_Q_brep_self, dim=0)
Q_pc_self = torch.cat(all_Q_pc_self, dim=0)
confidence = torch.cat(all_confidence, dim=0)

print(f"\nEmbeddings collected: {len(all_uids)} samples")
print(f"  z_brep_guided: {z_brep_guided.shape}")
print(f"  z_brep_self: {z_brep_self.shape}")
print(f"  T_feat: {T_feat.shape}")
print(f"  Q_brep_self: {Q_brep_self.shape}")

In [None]:
# 1. Self-grounding quality
# Section banner: blank line, rule, title, rule
print("\n" + "=" * 70, "SELF-GROUNDING QUALITY", "=" * 70, sep="\n")

# Compare each modality's guided embedding with its self-path counterpart
self_cos_brep = compute_self_grounding_quality(z_brep_guided, z_brep_self)
self_cos_pc = compute_self_grounding_quality(z_pc_guided, z_pc_self)

print(
    f"BRep: {self_cos_brep:.4f} (target > 0.85)",
    f"PC:   {self_cos_pc:.4f} (target > 0.85)",
    f"Avg:  {(self_cos_brep + self_cos_pc) / 2:.4f}",
    sep="\n",
)

In [None]:
# 2. Query alignment
# Section banner: blank line, rule, title, rule
print("\n" + "=" * 70, "QUERY ALIGNMENT", "=" * 70, sep="\n")

# Alignment of the text features with each modality's self-path queries;
# the confidence tensor is forwarded to the helper unchanged.
q_align_brep = compute_query_alignment(T_feat, Q_brep_self, confidence)
q_align_pc = compute_query_alignment(T_feat, Q_pc_self, confidence)

print(
    f"BRep: {q_align_brep:.4f} (target > 0.7)",
    f"PC:   {q_align_pc:.4f} (target > 0.7)",
    f"Avg:  {(q_align_brep + q_align_pc) / 2:.4f}",
    sep="\n",
)

In [None]:
# 3. Retrieval metrics
def compute_recall_at_k(queries, keys, k_values=(1, 5, 10)):
    """Compute Recall@K for paired retrieval.

    Row ``i`` of ``queries`` is treated as matching row ``i`` of ``keys``.
    Both inputs are L2-normalized along the last dimension, so ranking is by
    cosine similarity.

    Args:
        queries: 2-D tensor of query embeddings, one row per query.
        keys: 2-D tensor of key embeddings with the same feature size.
        k_values: Iterable of cut-offs to evaluate. A cut-off larger than the
            number of keys is evaluated over all keys instead of raising.

    Returns:
        Dict mapping ``'R@{k}'`` to the percentage (0-100) of queries whose
        matching key appears among the top-k most similar keys. With zero
        queries the values are NaN (mean of an empty tensor).
    """
    queries = F.normalize(queries, dim=-1)
    keys = F.normalize(keys, dim=-1)

    sim = queries @ keys.T  # (num_queries, num_keys)
    num_keys = sim.size(1)

    # Ground-truth key index for each query, created on the same device as
    # the similarities so the comparison also works for CUDA inputs.
    targets = torch.arange(sim.size(0), device=sim.device).unsqueeze(1)

    results = {}
    for k in k_values:
        # topk raises when k exceeds the number of candidates; clamp instead.
        _, topk_indices = sim.topk(min(k, num_keys), dim=-1)
        correct = (topk_indices == targets).any(dim=1)
        results[f'R@{k}'] = correct.float().mean().item() * 100

    return results

# Section banner: blank line, rule, title, rule
print("\n" + "=" * 70, "RETRIEVAL METRICS", "=" * 70, sep="\n")

# Text queries against guided BRep embeddings
recall_brep_guided = compute_recall_at_k(z_text, z_brep_guided)
print("\nText → BRep (guided):")
for metric, value in recall_brep_guided.items():
    print(f"  {metric}: {value:.2f}%")

# Text queries against self-path BRep embeddings
recall_brep_self = compute_recall_at_k(z_text, z_brep_self)
print("\nText → BRep (self):")
for metric, value in recall_brep_self.items():
    print(f"  {metric}: {value:.2f}%")

# Text queries against guided point-cloud embeddings
recall_pc_guided = compute_recall_at_k(z_text, z_pc_guided)
print("\nText → PC (guided):")
for metric, value in recall_pc_guided.items():
    print(f"  {metric}: {value:.2f}%")

# Text queries against self-path point-cloud embeddings
recall_pc_self = compute_recall_at_k(z_text, z_pc_self)
print("\nText → PC (self):")
for metric, value in recall_pc_self.items():
    print(f"  {metric}: {value:.2f}%")

# R@1 lost when moving from the guided path to the self path
gap_brep = recall_brep_guided['R@1'] - recall_brep_self['R@1']
gap_pc = recall_pc_guided['R@1'] - recall_pc_self['R@1']
print(
    "\nGap (guided - self):",
    f"  BRep: {gap_brep:.2f}% (target < 10%)",
    f"  PC:   {gap_pc:.2f}% (target < 10%)",
    sep="\n",
)

In [None]:
# 4. t-SNE visualization
# (This cell was labelled "UMAP", but the projection it computes is
# scikit-learn's t-SNE; the labels now match the code.)
print("\nComputing t-SNE embeddings...")

# Sample for visualization. A seeded generator keeps the plotted subset
# reproducible, in line with the fixed random_state passed to t-SNE below.
n_samples = min(2000, len(z_text))
rng = np.random.default_rng(42)
indices = torch.as_tensor(
    rng.choice(len(z_text), n_samples, replace=False), dtype=torch.long
)

# Combine embeddings: guided BRep, self BRep, then text, in that row order
combined = torch.cat([
    z_brep_guided[indices],
    z_brep_self[indices],
    z_text[indices]
], dim=0).numpy()

# t-SNE. scikit-learn requires perplexity < number of points, so cap the
# usual value of 30 for very small validation sets.
from sklearn.manifold import TSNE
perplexity = min(30, max(1, combined.shape[0] - 1))
tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
coords = tsne.fit_transform(combined)

# Split back into the three groups using the row order from above
brep_guided_coords = coords[:n_samples]
brep_self_coords = coords[n_samples:2*n_samples]
text_coords = coords[2*n_samples:]

# Plot
plt.figure(figsize=(10, 8))
plt.scatter(brep_guided_coords[:, 0], brep_guided_coords[:, 1], c='green', alpha=0.5, s=10, label='BRep (guided)')
plt.scatter(brep_self_coords[:, 0], brep_self_coords[:, 1], c='orange', alpha=0.5, s=10, label='BRep (self)')
plt.scatter(text_coords[:, 0], text_coords[:, 1], c='blue', alpha=0.5, s=10, label='Text')
plt.legend()
plt.title('t-SNE: Guided vs Self Embeddings')
plt.xlabel('Dimension 1')
plt.ylabel('Dimension 2')
plt.tight_layout()
plt.show()

In [None]:
# Summary
# Section banner: blank line, rule, title, rule
print("\n" + "=" * 70, "SUMMARY", "=" * 70, sep="\n")

# Embedding-level metrics
print(
    "\nSelf-grounding quality:",
    f"  BRep: {self_cos_brep:.4f}",
    f"  PC:   {self_cos_pc:.4f}",
    sep="\n",
)
print(
    "\nQuery alignment:",
    f"  BRep: {q_align_brep:.4f}",
    f"  PC:   {q_align_pc:.4f}",
    sep="\n",
)

# Retrieval R@1 per target modality, with the guided-minus-self gap
print(
    "\nText→BRep R@1:",
    f"  Guided: {recall_brep_guided['R@1']:.2f}%",
    f"  Self:   {recall_brep_self['R@1']:.2f}%",
    f"  Gap:    {gap_brep:.2f}%",
    sep="\n",
)
print(
    "\nText→PC R@1:",
    f"  Guided: {recall_pc_guided['R@1']:.2f}%",
    f"  Self:   {recall_pc_self['R@1']:.2f}%",
    f"  Gap:    {gap_pc:.2f}%",
    sep="\n",
)