# CLIP4CAD-GFA v4.2 Training

**Key Innovation: Conditional Self-Query Generation with Curriculum Learning**

## Problem with v4
The self-query decoder is never shown what its target T_feat looks like, and the query-distillation loss alone is not enough to teach it.

## Solution: Curriculum Learning
During training, show the model T_feat as a conditioning hint, then gradually withdraw the hints.

### Curriculum Schedule
- **Epochs 1-3**: conditioning dropout 0.1 (90% of samples get hints) — learn the output distribution
- **Epochs 4-7**: conditioning dropout 0.3 (70% of samples get hints) — start becoming independent
- **Epochs 8-11**: conditioning dropout 0.5 (50% of samples get hints) — balanced
- **Epochs 12-15**: conditioning dropout 0.7 (30% of samples get hints) — mostly independent
- **Stage 2**: conditioning dropout 1.0 (0% hints) — fully independent

### New Loss: Distribution Matching
- Matches the batch statistics (mean and std) of Q_self to those of T_feat
- Regularizes the feature space during the curriculum transition

## Loss Weights
| Stage | λ_self | λ_query | λ_embed | λ_dist | λ_detail |
|-------|--------|---------|---------|--------|----------|
| 1 | 0.1 | **1.5** | 0.3 | 0.3 | 0.0 |
| 2 | 0.3 | 1.0 | 0.3 | 0.2 | 0.3 |

## Success Criteria
- Self-cos BRep ≥ 0.75 (should NOT collapse!)
- Query alignment ≥ 0.70
- Gap (guided - self) < 10%

In [None]:
# Cell 1: Imports and Setup
import sys
sys.path.insert(0, '..')

import os
import gc
import math
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from tqdm.auto import tqdm
import numpy as np
from pathlib import Path
import json

from clip4cad.models import CLIP4CAD_GFA_v4_2, GFAv4_2Config, get_cond_dropout
from clip4cad.losses import GFAv4_2Loss
from clip4cad.losses.gfa_v4_2_losses import compute_self_grounding_quality, compute_query_alignment

# Set device
# Use CUDA when a GPU is visible to torch, otherwise fall back to the CPU,
# and report which device (and, for CUDA, which GPU / how much memory) is used.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"Device: {device}")
if has_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Cell 2: Data Paths
# NOTE: machine-specific absolute locations — edit these for your environment.
# The output directory is created up front so later cells can write into it.

DATA_ROOT = Path("d:/Defect_Det/MMCAD/data")
PC_FILE = Path("c:/Users/User/Desktop/pc_embeddings_full.h5")
BREP_FILE = Path("c:/Users/User/Desktop/brep_features.h5")
TEXT_FILE = Path("c:/Users/User/Desktop/text_embeddings.h5")
OUTPUT_DIR = Path("../outputs/gfa_v4_2")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"Data root: {DATA_ROOT}")
for feature_label, feature_path in (
    ("PC file", PC_FILE),
    ("BRep file", BREP_FILE),
    ("Text file", TEXT_FILE),
):
    print(f"{feature_label}: {feature_path} (exists: {feature_path.exists()})")
print(f"Output: {OUTPUT_DIR}")

In [None]:
# Cell 3: Load Data using GFAMappedDataset

from clip4cad.data.gfa_dataset import GFAMappedDataset

print("Loading datasets...")
print("=" * 60)

# Both splits read the same feature files with the same options; they differ
# only in the split name and whether the data is preloaded into RAM.
_shared_dataset_kwargs = dict(
    data_root=str(DATA_ROOT),
    pc_file=str(PC_FILE),
    text_file=str(TEXT_FILE),
    brep_file=str(BREP_FILE),
    num_rotations=1,
    use_live_text=False,
)

# Train dataset - LOAD TO RAM for fast training
print("\n[1/2] Loading TRAIN dataset to RAM...")
train_dataset = GFAMappedDataset(split="train", load_to_memory=True, **_shared_dataset_kwargs)
print(f"Train: {len(train_dataset):,} samples in RAM")

# Val dataset - ON DISK (saves RAM)
print("\n[2/2] Loading VAL dataset (on disk)...")
val_dataset = GFAMappedDataset(split="val", load_to_memory=False, **_shared_dataset_kwargs)
print(f"Val: {len(val_dataset):,} samples on disk")

banner = "=" * 60
print("\n" + banner)
print("DATASETS READY!")
print(banner)

In [None]:
# Cell 4: Verify Dataset
# Spot-check the first training sample: list its keys, then print the shape of
# each tensor the model consumes.

sample = train_dataset[0]
print(f"Sample keys: {list(sample.keys())}")
for field_name in (
    'brep_face_features',
    'brep_edge_features',
    'brep_face_mask',
    'pc_features',
    'desc_embedding',
):
    print(f"  {field_name}: {sample[field_name].shape}")

In [None]:
# Cell 5: Create Model
# NOTE: Uses a lighter architecture to match v2 training time.
# The original v4.2 design had 4+4 BRep layers, which was ~6x slower than v2.

# Input feature widths per modality.
_input_dims = dict(
    d_face=48,     # FSQ face features
    d_edge=12,     # FSQ edge features
    d_pc=1024,     # ShapeLLM features
    d_text=3072,   # Phi-4-mini features
)
# Shared latent / projection widths, query counts and attention settings.
_latent_dims = dict(
    d_unified=256,
    d_proj=128,
    d_ground=128,
    num_slots=12,
    num_detail_queries=8,
    num_heads=8,
    num_parser_layers=2,
)
# LIGHTER ARCHITECTURE (was 4+4 for BRep, 2+2 for PC)
_layer_depths = dict(
    brep_encoder_layers=2,  # reduced from 4
    brep_decoder_layers=2,  # reduced from 4
    pc_encoder_layers=1,    # reduced from 2
    pc_decoder_layers=2,    # kept at 2
)
config = GFAv4_2Config(**_input_dims, **_latent_dims, **_layer_depths, dropout=0.1)

model = CLIP4CAD_GFA_v4_2(config).to(device)
total_param_count = model.count_parameters()
trainable_param_count = model.count_parameters(trainable_only=True)
print(f"Model parameters: {total_param_count:,}")
print(f"Trainable parameters: {trainable_param_count:,}")
print(f"\nArchitecture (lighter for faster training):")
print(f"  BRep: {config.brep_encoder_layers} encoder + {config.brep_decoder_layers} decoder layers")
print(f"  PC:   {config.pc_encoder_layers} encoder + {config.pc_decoder_layers} decoder layers")

In [None]:
# Cell 6: Training Configuration

from clip4cad.data.gfa_dataset import gfa_collate_fn

# Hyperparameters
BATCH_SIZE = 512
STAGE1_EPOCHS, STAGE2_EPOCHS = 15, 20   # Stage 1: curriculum with hints; Stage 2: no hints
STAGE1_LR, STAGE2_LR = 3e-5, 1e-5
WARMUP_EPOCHS = 3
MAX_GRAD_NORM = 1.0

# Loss weights - v4.2 with curriculum learning.
# Stage 1: heavy query distillation plus distribution matching while hints are shown.
STAGE1_LAMBDA_SELF = 0.1
STAGE1_LAMBDA_QUERY = 1.5       # heavy query distillation
STAGE1_LAMBDA_EMBED = 0.3
STAGE1_LAMBDA_DIST = 0.3        # distribution matching
STAGE1_LAMBDA_DETAIL = 0.0

# Stage 2: balanced weights + hard negatives (no hints).
STAGE2_LAMBDA_SELF = 0.3
STAGE2_LAMBDA_QUERY = 1.0
STAGE2_LAMBDA_EMBED = 0.3
STAGE2_LAMBDA_DIST = 0.2
STAGE2_LAMBDA_DETAIL = 0.3

# Loader options common to both splits.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=True,
    collate_fn=gfa_collate_fn,
)
# Train: reshuffled every epoch; the final incomplete batch is dropped.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_kwargs)
# Val: fixed order, every sample kept.
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Batch size: {BATCH_SIZE}")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Total epochs: {STAGE1_EPOCHS + STAGE2_EPOCHS}")
print(f"\nStage 1 weights: self={STAGE1_LAMBDA_SELF}, query={STAGE1_LAMBDA_QUERY}, dist={STAGE1_LAMBDA_DIST}")
print(f"Stage 2 weights: self={STAGE2_LAMBDA_SELF}, query={STAGE2_LAMBDA_QUERY}, dist={STAGE2_LAMBDA_DIST}")

In [None]:
# Cell 7: Initialize Optimizer, Loss, Scheduler, and Hard Negative Miner

from clip4cad.training.hard_negative_mining import HardNegativeMiner

# AdamW over every model parameter, starting from the Stage-1 learning rate.
optimizer = AdamW(model.parameters(), lr=STAGE1_LR, betas=(0.9, 0.999), weight_decay=0.01)

# The loss starts with the Stage-1 weights; the training loop later switches
# them to the Stage-2 values through criterion.update_weights(...).
_stage1_loss_weights = {
    'lambda_self': STAGE1_LAMBDA_SELF,
    'lambda_query': STAGE1_LAMBDA_QUERY,
    'lambda_embed': STAGE1_LAMBDA_EMBED,
    'lambda_dist': STAGE1_LAMBDA_DIST,
    'lambda_detail': STAGE1_LAMBDA_DETAIL,
}
criterion = GFAv4_2Loss(**_stage1_loss_weights)

# Loss scaler for the mixed-precision (autocast) training steps.
scaler = GradScaler()

# Learning rate scheduler with warmup: the multiplier ramps linearly from 0 over
# WARMUP_EPOCHS, then follows a half-cosine to the end of training. After
# warmup it is floored at 1e-6 / STAGE1_LR, i.e. an LR of 1e-6 at the Stage-1 base.
steps_per_epoch = len(train_loader)
total_epochs = STAGE1_EPOCHS + STAGE2_EPOCHS
warmup_steps = WARMUP_EPOCHS * steps_per_epoch
total_steps = total_epochs * steps_per_epoch


def lr_lambda(step):
    """Return the LR multiplier LambdaLR applies to the base LR at ``step``."""
    if step >= warmup_steps:
        decay_span = max(total_steps - warmup_steps, 1)
        progress = (step - warmup_steps) / decay_span
        cosine_factor = 0.5 * (1 + math.cos(math.pi * progress))
        return max(1e-6 / STAGE1_LR, cosine_factor)
    return step / max(warmup_steps, 1)


scheduler = LambdaLR(optimizer, lr_lambda)

# Hard negative miner (used in Stage 2).
# NOTE(review): it is handed the shuffled train_loader; confirm how the miner
# keys its results so they can be matched back to samples during training.
hard_neg_miner = HardNegativeMiner(
    model=model,
    train_dataloader=train_loader,
    cache_dir=str(OUTPUT_DIR / "hard_negatives"),
    device=str(device),
    k=20,
    text_sim_threshold=0.8,
    min_negatives=1,
    max_negatives=10,
    use_faiss=True,
)
MINE_EVERY_N_EPOCHS = 5   # Stage-2 re-mining period, in epochs
hard_negatives = None     # filled by hard_neg_miner.mine(...) once Stage 2 starts

print("Optimizer, loss, scheduler, and hard negative miner initialized.")
print(f"Query distillation weight: {STAGE1_LAMBDA_QUERY} (v4.2 curriculum learning)")
print(f"Distribution matching weight: {STAGE1_LAMBDA_DIST}")

In [None]:
# Cell 8: Training State
# Counters and best-so-far trackers that the training loop updates in place.

global_step = 0
current_stage = 1
best_val_loss = float('inf')
best_self_cosine = 0.0
best_query_align = 0.0

# Per-epoch metric traces, in a fixed key order. 'val_loss' only receives an
# entry on validation epochs.
_HISTORY_KEYS = (
    'train_loss',
    'val_loss',
    'self_cosine_brep',
    'self_cosine_pc',
    'query_align_brep',
    'query_align_pc',
    'cond_dropout',
    'dist_loss',
)
history = {metric: [] for metric in _HISTORY_KEYS}

print("Training state initialized.")

In [None]:
# Cell 9: Training Loop with Curriculum Learning
#
# Each epoch:
#   1. On the first Stage-2 epoch: snapshot Stage 1, switch the loss weights,
#      and lower the learning rate to STAGE2_LR.
#   2. Set the curriculum conditioning-dropout rate from get_cond_dropout().
#   3. In Stage 2: mine hard negatives on entry, then every
#      MINE_EVERY_N_EPOCHS epochs.
#   4. Train one epoch, tracking self-grounding cosine and query alignment.
#   5. Every 5 epochs: validate, save the best-by-val-loss model, and save a
#      periodic checkpoint.

# Batch keys that might carry each sample's dataset index.
# NOTE(review): which of these (if any) gfa_collate_fn emits is not visible
# here — confirm, or add one, so hard negatives can be matched to samples.
_SAMPLE_INDEX_KEYS = ('idx', 'index', 'indices', 'sample_idx')
_warned_positional_hard_negs = False


def _mean(values):
    """Return the arithmetic mean of ``values``, or 0 for an empty list."""
    return sum(values) / len(values) if values else 0


def _batch_hard_negatives(batch, batch_idx):
    """Return one hard-negative entry (or None) per sample in ``batch``.

    ``train_loader`` is built with ``shuffle=True``, so
    ``batch_idx * BATCH_SIZE + i`` is NOT the dataset index of sample ``i``.
    Use an explicit per-sample index from the batch when one is present;
    otherwise fall back to that positional guess (the previous behaviour) and
    warn once that the lookup is unreliable.
    """
    global _warned_positional_hard_negs
    n_samples = batch['brep_face_features'].shape[0]
    sample_ids = None
    for key in _SAMPLE_INDEX_KEYS:
        if key in batch:
            ids = batch[key]
            sample_ids = ids.tolist() if torch.is_tensor(ids) else list(ids)
            break
    if sample_ids is None:
        if not _warned_positional_hard_negs:
            print("WARNING: batch carries no per-sample index; hard negatives are looked up "
                  "by batch position, which does not match dataset indices with shuffle=True.")
            _warned_positional_hard_negs = True
        start = batch_idx * BATCH_SIZE
        sample_ids = range(start, start + n_samples)
    return [hard_negatives[i] if i in hard_negatives else None for i in sample_ids]


def _rebase_lr_schedule(new_lr):
    """Make ``new_lr`` the current LR in a way the LambdaLR scheduler preserves.

    LambdaLR recomputes ``lr = base_lr * lr_lambda(step)`` on every
    ``scheduler.step()``, so assigning ``param_group['lr']`` alone is undone
    after one batch. Rescale the scheduler's base LRs instead, so that the
    warmup/cosine shape continues from ``new_lr``.
    """
    for i, group in enumerate(optimizer.param_groups):
        factor = group['lr'] / scheduler.base_lrs[i]  # current lr_lambda multiplier
        new_base = new_lr / factor if factor > 0 else new_lr
        scheduler.base_lrs[i] = new_base
        group['initial_lr'] = new_base
        group['lr'] = new_lr


print("Starting training with curriculum learning...")
print("="*70)
print("Curriculum: hint rate 90% → 70% → 50% → 30% → 0%")
print("="*70)

for epoch in range(1, total_epochs + 1):
    # ── Stage transition ────────────────────────────────────────────────
    if epoch == STAGE1_EPOCHS + 1:
        print("\n" + "="*70)
        print("TRANSITIONING TO STAGE 2 (no hints)")
        print("="*70)

        # Snapshot Stage 1 BEFORE touching the optimizer/scheduler, so the
        # checkpoint holds the state Stage 1 actually ended with.
        torch.save({
            'epoch': epoch - 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_self_cosine': best_self_cosine,
            'best_query_align': best_query_align,
        }, OUTPUT_DIR / 'checkpoint_stage1_final.pt')
        print("Saved Stage 1 checkpoint")

        current_stage = 2

        # Update loss weights
        criterion.update_weights(
            lambda_self=STAGE2_LAMBDA_SELF,
            lambda_query=STAGE2_LAMBDA_QUERY,
            lambda_embed=STAGE2_LAMBDA_EMBED,
            lambda_dist=STAGE2_LAMBDA_DIST,
            lambda_detail=STAGE2_LAMBDA_DETAIL,
        )
        print(f"Updated loss weights: self={STAGE2_LAMBDA_SELF}, query={STAGE2_LAMBDA_QUERY}, dist={STAGE2_LAMBDA_DIST}, detail={STAGE2_LAMBDA_DETAIL}")

        # Reduce learning rate (through the scheduler, so it is not overwritten).
        _rebase_lr_schedule(STAGE2_LR)
        print(f"Reduced LR to {STAGE2_LR}")

        # Stage-2 loss weights differ from Stage 1, so val losses from the two
        # stages are not comparable. Restart best-checkpoint tracking so that
        # checkpoint_best.pt can hold a Stage-2 model (Stage 1 is saved above).
        best_val_loss = float('inf')

    # ── Curriculum: conditioning-dropout rate for this epoch ───────────
    cond_drop = get_cond_dropout(epoch, current_stage)
    model.set_cond_dropout(cond_drop)
    hint_rate = (1 - cond_drop) * 100

    # ── Stage 2 hard-negative mining ────────────────────────────────────
    # Done after set_cond_dropout, so any conditioning the model applies while
    # mining uses the Stage-2 rate rather than the last Stage-1 one.
    if current_stage == 2:
        epochs_into_stage2 = epoch - STAGE1_EPOCHS - 1
        if epochs_into_stage2 == 0:
            print("\nMining hard negatives for Stage 2...")
            hard_negatives = hard_neg_miner.mine(epoch=epoch)
            print(f"Mined hard negatives for {len(hard_negatives)} samples")
        elif epochs_into_stage2 % MINE_EVERY_N_EPOCHS == 0:
            print(f"\nRe-mining hard negatives (epoch {epoch})...")
            hard_negatives = hard_neg_miner.mine(epoch=epoch)
            print(f"Re-mined hard negatives for {len(hard_negatives)} samples")

    # ── Train epoch ─────────────────────────────────────────────────────
    model.train()
    epoch_loss = 0.0
    epoch_self_cos_brep = []
    epoch_self_cos_pc = []
    epoch_query_align_brep = []
    epoch_query_align_pc = []
    epoch_dist_loss = []

    pbar = tqdm(train_loader, desc=f"Epoch {epoch} (Stage {current_stage}, {hint_rate:.0f}% hints)")
    for batch_idx, batch in enumerate(pbar):
        # Per-sample hard negatives (Stage 2 only, once mined).
        batch_hard_negs = None
        if current_stage == 2 and hard_negatives is not None:
            batch_hard_negs = _batch_hard_negatives(batch, batch_idx)

        with autocast():
            outputs = model(batch)
            loss, loss_dict = criterion(outputs, hard_negatives=batch_hard_negs, stage=current_stage)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale before clipping so the norm is in real units
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        global_step += 1
        epoch_loss += loss_dict['total']
        epoch_dist_loss.append(loss_dict.get('dist', 0))

        # Self-grounding quality: guided vs self-grounded embeddings.
        if outputs.get('z_brep') is not None and outputs.get('z_brep_self') is not None:
            epoch_self_cos_brep.append(compute_self_grounding_quality(
                outputs['z_brep'].detach(),
                outputs['z_brep_self'].detach(),
            ))
        if outputs.get('z_pc') is not None and outputs.get('z_pc_self') is not None:
            epoch_self_cos_pc.append(compute_self_grounding_quality(
                outputs['z_pc'].detach(),
                outputs['z_pc_self'].detach(),
            ))

        # Query alignment: T_feat vs self-generated queries.
        if outputs.get('T_feat') is not None and outputs.get('Q_brep_self') is not None:
            epoch_query_align_brep.append(compute_query_alignment(
                outputs['T_feat'].detach(),
                outputs['Q_brep_self'].detach(),
                outputs['confidence'].detach(),
            ))
        if outputs.get('T_feat') is not None and outputs.get('Q_pc_self') is not None:
            epoch_query_align_pc.append(compute_query_alignment(
                outputs['T_feat'].detach(),
                outputs['Q_pc_self'].detach(),
                outputs['confidence'].detach(),
            ))

        # Update progress bar
        postfix = {
            'loss': f"{loss_dict['total']:.3f}",
            'guided': f"{loss_dict['guided']:.3f}",
            'query': f"{loss_dict.get('query', 0):.3f}",
            'dist': f"{loss_dict.get('dist', 0):.3f}",
            'cos': f"{epoch_self_cos_brep[-1]:.3f}" if epoch_self_cos_brep else "N/A",
            'q_align': f"{epoch_query_align_brep[-1]:.3f}" if epoch_query_align_brep else "N/A",
        }
        if current_stage == 2:
            postfix['detail'] = f"{loss_dict.get('detail', 0):.3f}"
        pbar.set_postfix(postfix)

    # ── Epoch summary ───────────────────────────────────────────────────
    avg_loss = epoch_loss / len(train_loader)
    avg_cos_brep = _mean(epoch_self_cos_brep)
    avg_cos_pc = _mean(epoch_self_cos_pc)
    avg_q_align_brep = _mean(epoch_query_align_brep)
    avg_q_align_pc = _mean(epoch_query_align_pc)
    avg_dist_loss = _mean(epoch_dist_loss)

    history['train_loss'].append(avg_loss)
    history['self_cosine_brep'].append(avg_cos_brep)
    history['self_cosine_pc'].append(avg_cos_pc)
    history['query_align_brep'].append(avg_q_align_brep)
    history['query_align_pc'].append(avg_q_align_pc)
    history['cond_dropout'].append(cond_drop)
    history['dist_loss'].append(avg_dist_loss)

    best_self_cosine = max(best_self_cosine, avg_cos_brep)
    best_query_align = max(best_query_align, avg_q_align_brep)

    print(f"\nEpoch {epoch}: Loss={avg_loss:.4f}, Self-cos BRep={avg_cos_brep:.4f}, PC={avg_cos_pc:.4f}")
    print(f"  Query-align BRep={avg_q_align_brep:.4f}, PC={avg_q_align_pc:.4f}")
    print(f"  cond_dropout={cond_drop:.1f} ({hint_rate:.0f}% hints), dist_loss={avg_dist_loss:.4f}")
    print(f"  Best: self-cos={best_self_cosine:.4f}, query-align={best_query_align:.4f}")

    # ── Validation + checkpoints every 5 epochs ─────────────────────────
    if epoch % 5 == 0:
        model.eval()
        val_loss = 0.0
        val_cos_brep = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                with autocast():
                    outputs = model(batch)
                    loss, loss_dict = criterion(outputs, stage=current_stage)
                val_loss += loss_dict['total']

                if outputs.get('z_brep') is not None and outputs.get('z_brep_self') is not None:
                    val_cos_brep.append(compute_self_grounding_quality(
                        outputs['z_brep'],
                        outputs['z_brep_self'],
                    ))

        avg_val_loss = val_loss / len(val_loader)
        avg_val_cos = _mean(val_cos_brep)

        history['val_loss'].append(avg_val_loss)
        print(f"Validation: Loss={avg_val_loss:.4f}, Self-cos={avg_val_cos:.4f}")

        # Save best model (by val loss within the current stage).
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'best_val_loss': best_val_loss,
                'best_self_cosine': best_self_cosine,
                'best_query_align': best_query_align,
            }, OUTPUT_DIR / 'checkpoint_best.pt')
            print("Saved best model!")

        # Periodic full checkpoint (resumable: includes optimizer + scheduler).
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_self_cosine': best_self_cosine,
            'best_query_align': best_query_align,
        }, OUTPUT_DIR / f'checkpoint_epoch{epoch}.pt')

    # Free unreferenced Python objects first, then release cached GPU blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n" + "="*70)
print("Training Complete!")
print(f"Best self-grounding cosine: {best_self_cosine:.4f}")
print(f"Best query alignment: {best_query_align:.4f}")
print(f"Best validation loss: {best_val_loss:.4f}")
print("="*70)

In [None]:
# Cell 10: Save Final Model
# Persist the final weights together with the config, the best-seen metrics
# and the per-epoch history, so the run can be inspected or reloaded later.

final_model_path = OUTPUT_DIR / 'clip4cad_gfa_v4_2_final.pt'
final_payload = {
    'model_state_dict': model.state_dict(),
    'config': config.__dict__,
    'best_self_cosine': best_self_cosine,
    'best_query_align': best_query_align,
    'history': history,
}
torch.save(final_payload, final_model_path)

print(f"Final model saved to {final_model_path}")

In [None]:
# Cell 11: Plot Training History

import matplotlib.pyplot as plt

# Per-epoch history lists start at epoch 1. Plot them against real epoch
# numbers (not list indices) so the training curves line up with the
# validation points, which are recorded on epochs 5, 10, ..., and with the
# stage boundary.
plot_epochs = list(range(1, len(history['train_loss']) + 1))
stage2_boundary = STAGE1_EPOCHS + 0.5   # between the last Stage-1 and first Stage-2 epoch
SELF_COS_TARGET = 0.75                  # success criterion: Self-cos BRep >= 0.75
QUERY_ALIGN_TARGET = 0.70               # success criterion: Query alignment >= 0.70


def _finish_axes(ax, ylabel, title, target=None, legend=True):
    """Draw the Stage-2 marker (and optional target line), then label and grid ``ax``."""
    ax.axvline(x=stage2_boundary, color='r', linestyle='--', label='Stage 2')
    if target is not None:
        ax.axhline(y=target, color='g', linestyle=':', label='Target')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if legend:
        ax.legend()
    ax.grid(True)


fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Loss plot
ax = axes[0, 0]
ax.plot(plot_epochs, history['train_loss'], label='Train Loss')
if history['val_loss']:
    val_epochs = list(range(5, len(history['train_loss']) + 1, 5))[:len(history['val_loss'])]
    ax.plot(val_epochs, history['val_loss'], 'o-', label='Val Loss')
_finish_axes(ax, 'Loss', 'Training Loss')

# Self-grounding quality
ax = axes[0, 1]
ax.plot(plot_epochs, history['self_cosine_brep'], label='BRep')
ax.plot(plot_epochs, history['self_cosine_pc'], label='PC')
_finish_axes(ax, 'Cosine Similarity', 'Self-Grounding Quality', target=SELF_COS_TARGET)

# Query alignment
ax = axes[0, 2]
ax.plot(plot_epochs, history['query_align_brep'], label='BRep')
ax.plot(plot_epochs, history['query_align_pc'], label='PC')
_finish_axes(ax, 'Cosine Similarity', 'Query Alignment (T_feat vs Q_self)', target=QUERY_ALIGN_TARGET)

# Curriculum (cond_dropout)
ax = axes[1, 0]
ax.plot(plot_epochs, history['cond_dropout'], 'b-', linewidth=2)
ax.set_ylim(-0.05, 1.05)
_finish_axes(ax, 'Conditioning Dropout Rate', 'Curriculum Schedule (0=all hints, 1=no hints)', legend=False)

# Distribution loss
ax = axes[1, 1]
ax.plot(plot_epochs, history['dist_loss'], label='Distribution Loss')
_finish_axes(ax, 'Loss', 'Distribution Matching Loss')

# BRep: self-cos vs query-align
ax = axes[1, 2]
ax.plot(plot_epochs, history['self_cosine_brep'], label='Self-cos BRep')
ax.plot(plot_epochs, history['query_align_brep'], label='Query-align BRep', linestyle='--')
_finish_axes(ax, 'Cosine Similarity', 'BRep: Self-cos vs Query-align')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'training_history.png', dpi=150)
plt.show()

print(f"Plot saved to {OUTPUT_DIR / 'training_history.png'}")

In [None]:
# Cell 12: Summary
# Report the last-epoch metrics from `history` and the best values tracked
# during training.

summary_rule = "=" * 70
print("\n" + summary_rule)
print("TRAINING SUMMARY - GFA v4.2 (Curriculum Learning)")
print(summary_rule)

print(f"\nFinal metrics:")
for row_label, history_key in (
    ("Self-cos BRep: ", 'self_cosine_brep'),
    ("Self-cos PC:   ", 'self_cosine_pc'),
    ("Query-align BRep: ", 'query_align_brep'),
    ("Query-align PC:   ", 'query_align_pc'),
):
    print(f"  {row_label}{history[history_key][-1]:.4f}")

print(f"\nBest metrics:")
for row_label, best_value in (
    ("Best self-cosine: ", best_self_cosine),
    ("Best query-align: ", best_query_align),
    ("Best val loss: ", best_val_loss),
):
    print(f"  {row_label}{best_value:.4f}")

print(f"\nModel saved to: {OUTPUT_DIR / 'clip4cad_gfa_v4_2_final.pt'}")