# Document aligner project
## NEW Model Implementation
- U-Net with skip connections
- Flow predictor, differentiable warping
- Standard encoder-decoder

In [None]:
from dataset_loader import get_dataloaders, visualize_batch
from reconstruction_model import *
from reconstruction_model import MaskedL1Loss, MaskedMSELoss, SSIMLoss, UVReconstructionLoss
from training_val_dewarpnet import train_one_epoch, validate

import time 
import os
import glob
from typing import Dict, List, Tuple, Optional, Callable
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np


In [None]:
# Build the train/validation loaders for the synthetic pitch-sweep renders.
# img_size can be dropped to (256, 256) for quicker experiments.
train_loader, val_loader = get_dataloaders(
    img_size=(512, 512),
    train_split=0.8,
    batch_size=8,
    data_dir='/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep',
)

# Grab the first training batch so it can be inspected interactively.
sample_batch = next(iter(train_loader))
# Optional sanity check:
# visualize_batch(sample_batch, num_samples=4)

# Model description
- TODO

Training:
- [ ] Additional metrics (PSNR, SSIM)
- [ ] Learning rate scheduling
- [ ] Gradient clipping
- [ ] Mixed precision training
- [ ] Logging to tensorboard/wandb

Main training loop:
- [ ] Implement better model architecture
- [ ] Try different loss functions
- [ ] Add learning rate scheduling
- [ ] Implement early stopping
- [ ] Add visualization and logging
- [ ] Experiment with data augmentation
- [ ] Use pretrained model from HuggingFace
- [ ] Enable MASKED LOSSES
- [ ] Use DEPTH
- [ ] Use UV
- [ ] Use BORDER
- [ ] Try different OPTIMIZERS

## Training Loop

In [None]:
# Main training loop 

def _unpack_batch(batch, device):
    """Move the tensors of one dataloader batch onto ``device``.

    Reads the batch keys ``'rgb'``, ``'coords'``, ``'uv'`` and
    ``'ground_truth'`` plus the optional ``'border'`` key.

    Returns:
        Tuple ``(rgb, C_gt, B_gt, D_gt, mask)``; ``mask`` is ``None`` when the
        batch has no ``'border'`` entry.
    """
    mask = batch.get('border', None)
    if mask is not None:
        mask = mask.to(device)
    return (
        batch['rgb'].to(device),
        batch['coords'].to(device),
        batch['uv'].to(device),
        batch['ground_truth'].to(device),
        mask,
    )


def main():
    """
    Main training loop

    Builds the dataloaders, trains a ``DewarpNet`` with AdamW for
    ``NUM_EPOCHS`` epochs, validates after every epoch, and keeps the weights
    with the lowest average validation loss in ``best_dewarpnet_model.pth``.
    After training, the per-epoch losses are written to
    ``training_history.pth``, the loss curves are plotted, and one example
    prediction is displayed.

    *** CHECK TODO LIST ***

    """

    # NOTE(review): ``plt`` (and ``DewarpNet`` / ``compute_loss``) are not
    # imported explicitly in this file; presumably they arrive through
    # ``from reconstruction_model import *`` -- confirm, otherwise the plotting
    # at the end of this function fails after training has finished.

    # Configuration
    DATA_DIR = '/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep'
    BATCH_SIZE = 8
    NUM_EPOCHS = 50
    LEARNING_RATE = 1e-4
    IMG_SIZE = (256, 256)  # (512, 512) Using smaller images for faster training for now

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create dataloaders
    train_loader, val_loader = get_dataloaders(
        data_dir=DATA_DIR,
        batch_size=BATCH_SIZE,
        img_size=IMG_SIZE,
        use_depth=False,  # TODO: True if using depth info
        use_uv=True,  # TODO: True if using UV maps
        use_border=True  # TODO: True if using border masks for better training
    )

    # Visualize a batch (testing)
    sample_batch = next(iter(train_loader))
    print(f"Batch RGB shape: {sample_batch['rgb'].shape}")
    print(f"Batch GT coords: {sample_batch['coords'].shape}")
    print(f"Batch GT backward map: {sample_batch['uv'].shape}")
    if 'border' in sample_batch:
        print(f"Batch border mask shape: {sample_batch['border'].shape}")
    # visualize_batch(sample_batch) # Troubleshooting / sanity check

    # Create model
    model = DewarpNet().to(device)
    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

    # TODO: TRY DIFFERENT OPTIMIZERS
    # optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

    # Training loop
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(NUM_EPOCHS):
        # Epoch timing check
        start_epoch = time.time()

        print(f"\n{'='*50}")
        print(f"EPOCH {epoch+1}/{NUM_EPOCHS}")
        print(f"{'='*50}")

        # TRAIN
        model.train()
        total_train_loss = 0.0
        for batch in train_loader:
            rgb, C_gt, B_gt, D_gt, mask = _unpack_batch(batch, device)

            optimizer.zero_grad()
            D_hat, C_hat, B_hat = model(rgb)

            losses = compute_loss(
                C_hat=C_hat, C=C_gt,
                B_hat=B_hat, B=B_gt,
                D_hat=D_hat, D=D_gt,
                mask=mask,
                alpha=1.0,
                beta=1.0,
                lambda_grad=0.1,
                gamma=1.0,
                delta=1.0
            )
            loss = losses['total']
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f"Train Loss: {avg_train_loss:.4f}")

        # --- VALIDATE ---
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                rgb, C_gt, B_gt, D_gt, mask = _unpack_batch(batch, device)

                D_hat, C_hat, B_hat = model(rgb)
                # NOTE(review): validation relies on compute_loss's default
                # weights while training passes them explicitly; the two
                # curves are only comparable if the defaults match -- confirm.
                losses = compute_loss(
                    C_hat=C_hat, C=C_gt,
                    B_hat=B_hat, B=B_gt,
                    D_hat=D_hat, D=D_gt,
                    mask=mask
                )
                total_val_loss += losses['total'].item()

        avg_val_loss = total_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        # Save the best model.
        # Compare against the epoch average computed above; the previous code
        # read an undefined name ``val_loss`` and raised NameError here.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_dewarpnet_model.pth')
            print(f"Saved best model with val loss {avg_val_loss:.4f}")

        epoch_time = time.time() - start_epoch
        print(f"Epoch time: {epoch_time:.4f}s")

    print("\nTraining complete!")
    print(f"Best validation loss: {best_val_loss:.4f}")

    # Persist the loss history so the separate loss-plot cell, which loads
    # 'training_history.pth' with these two keys, has a file to read.
    torch.save(
        {'train_losses': train_losses, 'val_losses': val_losses},
        'training_history.pth',
    )

    # Loss curve
    plt.figure(figsize=(8, 5))
    # Size the x-axis from the recorded history so it always matches the data.
    epochs = range(1, len(train_losses) + 1)
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('DewarpNet Training & Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Show one example prediction for the batch sampled before training.
    sample_rgb = sample_batch['rgb'].to(device)
    model.eval()  # make inference mode explicit rather than relying on the loop's last state
    with torch.no_grad():
        D_hat, _, _ = model(sample_rgb)
    # assumes D_hat[0] is channel-first (C, H, W); permute to (H, W, C) for imshow -- TODO confirm
    img = D_hat[0].cpu().permute(1, 2, 0).numpy()
    plt.figure(figsize=(5,5))
    plt.imshow(img)
    plt.title('Example Predicted Unwarped Image')
    plt.axis('off')
    plt.show()

    

In [None]:
# Entry cell: smoke-test the dataset loader, then launch training.

print("Document Reconstruction Dataset Loader", "=" * 50, sep="\n")

# Quick loader smoke test (main() builds its own loaders afterwards).
train_loader, val_loader = get_dataloaders(
    img_size=(256, 256),  # (512, 512)
    batch_size=4,
    data_dir='/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep',
)

print("\nDataset loaded successfully!")

# Optional sanity check of one batch (disabled):
# print("\nVisualizing a sample batch ...")
# sample_batch = next(iter(train_loader))
# print(f"Batch shape - RGB: {sample_batch['rgb'].shape}, Ground Truth: {sample_batch['ground_truth'].shape}")
# visualize_batch(sample_batch, num_samples=min(4, sample_batch['rgb'].shape[0]))

# Run the full training loop.
main()

In [None]:
# Loss plot
#
# Reload the per-epoch losses stored in 'training_history.pth' and draw the
# training and validation curves on one figure.
# NOTE(review): ``plt`` is not imported explicitly in this file; presumably it
# is provided by ``from reconstruction_model import *`` -- confirm.

history = torch.load('training_history.pth')

train_losses = history['train_losses']
val_losses = history['val_losses']
# Derive the x-axis from the stored history instead of hard-coding 50 epochs,
# so runs of any length plot without a length mismatch.
epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(8, 5))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training & Validation Loss Curve')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()