In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import sys
import math

# Import from our sudoku.py module
sys.path.append('.')
from sudoku import (
    SudokuDataset, SudokuMDM, MDMTrainer, MDMSampler,
    MaskSchedule, check_sudoku_valid, evaluate_samples
)

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)

# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")


Using device: cuda
PyTorch version: 2.9.1+cu128


In [2]:
# Load model
model_path = Path('./checkpoints/sudoku_mdm_best.pt')
model_config = {
    'vocab_size': 10,  # 0 (mask) + 1-9 (digits)
    'd_model': 384,
    'nhead': 12,
    'num_layers': 4,
    'dim_feedforward': 512,
    'dropout': 0.1,
    'max_seq_len': 81  # 9x9 Sudoku
}

# Create model
model = SudokuMDM(**model_config).to(device)
model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])



<All keys matched successfully>

In [3]:
def generate_conditional_entropy_data(
    model: SudokuMDM, puzzles: torch.Tensor, k=9, batch_size=256
):
    """
    Generate ground truth conditional entropy matrices for multiple puzzles (batched).
    
    For each position i, we compute H(j | i, c) = sum_xi P(xi | c) * H(j | xi, c)
    
    Args:
        model: Trained SudokuMDM
        puzzles: [num_puzzles, 81] tensor of input puzzles
        k: Number of top tokens to consider (max 9 for digits 1-9)
        batch_size: Number of (puzzle, position, token) tuples to process at once
    
    Returns:
        CE_matrices: [num_puzzles, 81, 81] tensor of conditional entropy values H(j | i)
    """
    # Ensure puzzles is on the correct device and has right shape
    if puzzles.dim() == 1:
        puzzles = puzzles.unsqueeze(0)
    
    puzzles = puzzles.to(device)
    num_puzzles = puzzles.size(0)

    # print(f"Generating conditional entropy data for {num_puzzles} puzzles (batched)...")
    # print(f"Using top-{k} tokens weighted by their probabilities")
    
    with torch.no_grad():
        # Initial forward pass to get P(j | c) for all puzzles
        initial_logits = model(puzzles)  # Shape: (num_puzzles, 81, 10)
        initial_probs = F.softmax(initial_logits, dim=-1)  # Shape: (num_puzzles, 81, 10)
        
        CE_matrices = torch.zeros((num_puzzles, 81, 81)).to(device)
        
        # Build list of all (puzzle_idx, position, token, weight) tuples to process
        all_queries = []  # List of (puzzle_idx, pos_i, token, weight)
        
        for p_idx in range(num_puzzles):
            # Find masked positions for this puzzle
            masked_positions = (puzzles[p_idx] == 0).nonzero(as_tuple=True)[0].tolist()
            
            for i in masked_positions:
                # Get top-k tokens and their probabilities for position i
                topk_probs, topk_tokens = torch.topk(initial_probs[p_idx, i, 1:], k)
                topk_tokens += 1  # Adjust indices since we excluded token 0
                
                # Normalize probabilities
                weights = topk_probs / topk_probs.sum()
                
                for idx, token in enumerate(topk_tokens):
                    all_queries.append((p_idx, i, token.item(), weights[idx].item()))
        
        # print(f"Total queries to process: {len(all_queries)}")
        
        # Process in batches
        num_batches = (len(all_queries) + batch_size - 1) // batch_size
        
        for batch_idx in range(num_batches):
            start = batch_idx * batch_size
            end = min(start + batch_size, len(all_queries))
            batch_queries = all_queries[start:end]
            
            # Create batched input - each query modifies its corresponding puzzle
            batch_puzzles = torch.zeros((len(batch_queries), 81), dtype=torch.long, device=device)
            
            for b, (p_idx, pos_i, token, weight) in enumerate(batch_queries):
                batch_puzzles[b] = puzzles[p_idx].clone()
                batch_puzzles[b, pos_i] = token
            
            # Batched forward pass
            logits = model(batch_puzzles)  # Shape: (B, 81, 10)
            probs = F.softmax(logits, dim=-1)  # Shape: (B, 81, 10)
            
            # H(j | x_i) for all j, for each sample in batch
            H_j_given_xi = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)  # Shape: (B, 81)
            
            # Accumulate weighted contributions to the correct puzzle's CE matrix
            for b, (p_idx, pos_i, token, weight) in enumerate(batch_queries):
                CE_matrices[p_idx, pos_i] += weight * H_j_given_xi[b]
            
            if (batch_idx + 1) % 20 == 0 or batch_idx == num_batches - 1:
                #print(f"  Processed batch {batch_idx + 1}/{num_batches}")
                ...
        # CE should be non-negative
        CE_matrices = torch.clamp(CE_matrices, min=0)
        
    return CE_matrices

In [4]:
class CEPredictor(nn.Module):
    """
    Lightweight module to predict conditional entropy matrix H(j | i, c) given a puzzle c.
    
    Uses frozen embeddings from a pre-trained MDM model, then adds a small trainable head
    to predict pairwise conditional entropy.
    
    Architecture:
    1. Get contextualized embeddings from frozen MDM (no grad)
    2. Project to smaller dimension
    3. Bilinear pairwise prediction head
    
    Output: [batch_size, 81, 81] matrix where entry (i, j) = H(X_j | X_i, c)
    """
    def __init__(
        self,
        mdm_model: SudokuMDM,
        hidden_dim: int = 64,
        dropout: float = 0.1,
        freeze_mdm: bool = True
    ):
        super(CEPredictor, self).__init__()
        
        self.mdm = mdm_model
        self.mdm_dim = mdm_model.d_model  # e.g., 384
        self.hidden_dim = hidden_dim
        
        # Freeze MDM parameters
        if freeze_mdm:
            for param in self.mdm.parameters():
                param.requires_grad = False
        
        # Project MDM embeddings to smaller dimension
        self.proj = nn.Linear(self.mdm_dim, hidden_dim)
        
        # Separate projections for "source" (i) and "target" (j) roles
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)  # "if I reveal i..."
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)    # "...entropy of j"
        
        # Small MLP head for final prediction
        self.ce_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )
        
        self._init_weights()
    
    def _init_weights(self):
        # Only init the trainable layers (not MDM)
        for module in [self.proj, self.query_proj, self.key_proj, self.ce_head]:
            for p in module.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
    
    def get_mdm_embeddings(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract contextualized embeddings from the MDM.
        
        Args:
            x: [batch_size, 81] - puzzle tokens
            
        Returns:
            h: [batch_size, 81, mdm_dim] - contextualized representations
        """
        # Run through MDM's embedding + positional encoding + transformer
        # but stop before the output head
        with torch.no_grad():
            h = self.mdm.embedding(x)  # [B, 81, d_model]
            h = self.mdm.pos_encoder(h)  # Add positional encoding
            h = self.mdm.transformer(h)  # [B, 81, d_model]
        return h
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch_size, 81] - puzzle with token indices (0 = mask, 1-9 = digits)
            
        Returns:
            ce_matrix: [batch_size, 81, 81] - predicted H(j | i, c) for all pairs
        """
        # Get frozen MDM embeddings
        h = self.get_mdm_embeddings(x)  # [B, 81, mdm_dim]
        h.requires_grad = True
        # Project to smaller dimension
        h = self.proj(h)  # [B, 81, hidden_dim]
        
        # Compute pairwise CE predictions
        # Query: "if I reveal position i..."
        # Key: "...what is the entropy of position j?"
        queries = self.query_proj(h)  # [B, 81, hidden_dim]
        keys = self.key_proj(h)       # [B, 81, hidden_dim]
        
        # Pairwise interaction: element-wise product
        # [B, 81, 1, hidden_dim] * [B, 1, 81, hidden_dim] -> [B, 81, 81, hidden_dim]
        interaction = queries.unsqueeze(2) * keys.unsqueeze(1)
        
        # Predict CE from interaction features
        ce_matrix = self.ce_head(interaction).squeeze(-1)  # [B, 81, 81]
        
        # CE must be non-negative
        ce_matrix = F.softplus(ce_matrix)
        
        return ce_matrix


class CEPredictorTrainer:
    """
    Trainer for the CE predictor with on-the-fly CE generation.
    
    Instead of using pre-computed CE matrices, this trainer:
    1. Takes solutions from a DataLoader
    2. Applies random masking to create puzzles
    3. Generates ground truth CE matrices on-the-fly
    4. Trains the predictor on these
    """
    
    def __init__(
        self,
        predictor: CEPredictor,
        dataset: SudokuDataset,
        mask_schedule: MaskSchedule,
        lr: float = 1e-3,
        weight_decay: float = 0.01,
        batch_size: int = 32,
        ce_batch_size: int = 256,
        k: int = 9
    ):
        self.predictor = predictor
        self.dataset = dataset
        self.mask_schedule = mask_schedule
        self.batch_size = batch_size
        self.ce_batch_size = ce_batch_size  # batch size for CE generation
        self.k = k  # top-k tokens for CE computation
        
        # Create DataLoader
        self.dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True
        )
        self.dataloader_iter = iter(self.dataloader)
        
        # Only optimize the trainable parameters (not frozen MDM)
        trainable_params = [p for p in predictor.parameters() if p.requires_grad]
        print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
        
        self.optimizer = torch.optim.AdamW(
            trainable_params,
            lr=lr,
            weight_decay=weight_decay
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=1000
        )
        self.train_losses = []
    
    def _get_batch(self):
        """Get next batch from dataloader, reset if exhausted."""
        try:
            batch = next(self.dataloader_iter)
        except StopIteration:
            self.dataloader_iter = iter(self.dataloader)
            batch = next(self.dataloader_iter)
        return batch
    
    def _apply_random_masking(self, solutions: torch.Tensor) -> torch.Tensor:
        """
        Apply random masking to solutions to create puzzles.
        
        Args:
            solutions: [batch_size, 81] - complete Sudoku solutions
            
        Returns:
            puzzles: [batch_size, 81] - masked puzzles
        """
        puzzles = solutions.clone()
        t = np.random.randint(1, self.mask_schedule.total_steps + 1)  # At least 1 mask
        mask_ratio = self.mask_schedule.get_mask_ratio(t)
        puzzles, _ = self.mask_schedule.apply_mask(puzzles, mask_ratio)
        return puzzles
    
    def train_step(self) -> float:
        """
        Single training step with on-the-fly CE generation.
        
        Returns:
            loss value
        """
        self.predictor.train()
        self.optimizer.zero_grad()
        
        # Get batch of solutions
        _, solutions = self._get_batch()
        solutions = solutions.to(next(self.predictor.parameters()).device)
        
        # Apply random masking to create puzzles
        puzzles = self._apply_random_masking(solutions)
        
        # Generate ground truth CE matrices on-the-fly
        target_ce = generate_conditional_entropy_data(
            self.predictor.mdm,
            puzzles,
            k=self.k,
            batch_size=self.ce_batch_size
        )
        
        # Forward pass
        pred_ce = self.predictor(puzzles)  # [B, 81, 81]
        
        # Only compute loss on masked positions (rows where puzzle[i] == 0)
        mask = (puzzles == 0)  # [B, 81]
        row_mask = mask.unsqueeze(2).expand_as(pred_ce)  # [B, 81, 81]
        
        # MSE loss on valid entries
        if row_mask.any():
            loss = F.mse_loss(pred_ce[row_mask], target_ce[row_mask])
            loss.backward()
        else:
            loss = torch.tensor(0.0, device=puzzles.device)
        
        pred_ce: torch.Tensor

        # Backward pass
        
        torch.nn.utils.clip_grad_norm_(self.predictor.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        
        self.train_losses.append(loss.item())
        return loss.item()
    
    def train_epoch(self, steps_per_epoch: int = None) -> float:
        """
        Train for one epoch (or specified number of steps).
        
        Args:
            steps_per_epoch: Number of steps. If None, use len(dataloader).
            
        Returns:
            Average loss for the epoch
        """
        if steps_per_epoch is None:
            steps_per_epoch = len(self.dataloader)
        
        epoch_losses = []
        for step in range(steps_per_epoch):
            loss = self.train_step()
            epoch_losses.append(loss)
            
            if (step + 1) % 10 == 0:
                avg_loss = sum(epoch_losses[-10:]) / min(10, len(epoch_losses))
                print(f"  Step {step + 1}/{steps_per_epoch}, Loss: {avg_loss:.6f}")
        
        return sum(epoch_losses) / len(epoch_losses)
    
    @torch.no_grad()
    def evaluate(self, num_samples: int = 100) -> dict:
        """Evaluate predictor on random validation samples."""
        self.predictor.eval()
        
        # Generate some validation puzzles
        """TypeError: expected Tensor as element 0 in argument 0, but got tuple"""
        solutions = torch.stack([self.dataset[i][1] for i in range(num_samples)])
        solutions = solutions.to(next(self.predictor.parameters()).device)
        puzzles = self._apply_random_masking(solutions)
        
        # Generate ground truth CE
        target_ce = generate_conditional_entropy_data(
            self.predictor.mdm,
            puzzles,
            k=self.k,
            batch_size=self.ce_batch_size
        )
        
        # Predict
        pred_ce = self.predictor(puzzles)
        
        mask = (puzzles == 0)
        row_mask = mask.unsqueeze(2).expand_as(pred_ce)
        
        if row_mask.any():
            mse = F.mse_loss(pred_ce[row_mask], target_ce[row_mask]).item()
            mae = F.l1_loss(pred_ce[row_mask], target_ce[row_mask]).item()
        else:
            mse, mae = 0.0, 0.0
        
        return {'mse': mse, 'mae': mae}


# Create predictor that uses the frozen MDM embeddings
ce_predictor = CEPredictor(
    mdm_model=model,
    hidden_dim=64,
    dropout=0.1,
    freeze_mdm=True
).to(device)

# Count only trainable parameters
trainable_params = sum(p.numel() for p in ce_predictor.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in ce_predictor.parameters())
print(f"CE Predictor:")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Total parameters (including frozen MDM): {total_params:,}")
print(f"  MDM parameters (frozen): {total_params - trainable_params:,}")

CE Predictor:
  Trainable parameters: 37,185
  Total parameters (including frozen MDM): 4,195,531
  MDM parameters (frozen): 4,158,346


In [5]:
# train!
dataset = SudokuDataset(data_path='./data/sudoku.csv', num_samples=100000)
trainer = CEPredictorTrainer(
    predictor=ce_predictor,
    dataset=dataset,
    mask_schedule=MaskSchedule(schedule_type='linear', total_steps=40),
    lr=1e-3,
    weight_decay=0.01,
    batch_size=32,
    ce_batch_size=256,
    k=9
)

num_epochs = 10
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    avg_loss = trainer.train_epoch(steps_per_epoch=100)
    print(f"  Average Loss: {avg_loss:.6f}")
    
    eval_metrics = trainer.evaluate(num_samples=100)
    print(f"  Eval MSE: {eval_metrics['mse']:.6f}, MAE: {eval_metrics['mae']:.6f}")

# Save trained predictor
torch.save(
    ce_predictor.state_dict(),
    './checkpoints/ce_predictor_sudoku_mdm.pt'
)


Loading Sudoku from CSV: ./data/sudoku.csv (start=0, n=100000)
✓ Loaded 100000 puzzles from CSV: ./data/sudoku.csv
Trainable parameters: 37,185
Epoch 1/10
  Step 10/100, Loss: 9948705.338833
  Step 20/100, Loss: 2.622402
  Step 30/100, Loss: 3.006485
  Step 40/100, Loss: 2.504423
  Step 50/100, Loss: 2.730825
  Step 60/100, Loss: 3.220881
  Step 70/100, Loss: 2.807148
  Step 80/100, Loss: 2.405799
  Step 90/100, Loss: 2.623364
  Step 100/100, Loss: 2.519560
  Average Loss: 994872.977972
  Eval MSE: 2.081733, MAE: 1.394552
Epoch 2/10
  Step 10/100, Loss: 2.638185
  Step 20/100, Loss: 2.380998
  Step 30/100, Loss: 2.245148
  Step 40/100, Loss: 2.519163
  Step 50/100, Loss: 2.418856
  Step 60/100, Loss: 2.624129
  Step 70/100, Loss: 2.690426
  Step 80/100, Loss: 2.799759
  Step 90/100, Loss: 2.730187
  Step 100/100, Loss: 2.728065
  Average Loss: 2.577492
  Eval MSE: 1.671863, MAE: 1.138784
Epoch 3/10
  Step 10/100, Loss: 2.227292
  Step 20/100, Loss: 2.932367
  Step 30/100, Loss: 2.57852

In [6]:
dataset = SudokuDataset(data_path='./data/sudoku.csv', num_samples=100000)

Loading Sudoku from CSV: ./data/sudoku.csv (start=0, n=100000)
✓ Loaded 100000 puzzles from CSV: ./data/sudoku.csv
