# Document aligner project
## NEW Model Implementation
### LAST ITERATION BEFORE FINAL 
- U-Net with skip connections
- Flow predictor, differentiable warping
- Standard encoder-decoder

In [None]:
from dataset_loader import get_dataloaders, visualize_batch, visualize_flow_and_warp, visualize_batch_pred
from reconstruction_model import ResNetUnet
from reconstruction_model import MaskedL1Loss, MaskedMSELoss, SSIMLoss, UVReconstructionLoss
from training_val import train_one_epoch, validate

import time 
import os
import glob
from typing import Dict, List, Tuple, Optional, Callable
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np


In [None]:
# Build the train/validation loaders from the pitch-sweep renders:
# 80% of the data goes to training, images are requested at 512x512
# (256x256 is the smaller alternative).
train_loader, val_loader = get_dataloaders(
    img_size=(512, 512),
    train_split=0.8,
    batch_size=8,
    data_dir='/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep',
)

# Pull one training batch so it is available for inspection.
sample_batch = next(iter(train_loader))
# Optional: visualize_batch(sample_batch, num_samples=4)

# Model description
- TODO

Training:
- [x] Additional metrics (PSNR, SSIM)
- [x] Learning rate scheduling
- [ ] Gradient clipping
- [ ] Mixed precision training
- [ ] Logging to tensorboard/wandb

Main training loop:
- [x] Implement better model architecture
- [x] Try different loss functions
- [x] Add learning rate scheduling
- [ ] Implement early stopping
- [ ] Add visualization and logging
- [ ] Experiment with data augmentation
- [ ] Use pretrained model from HuggingFace
- [ ] Enable MASKED LOSSES
- [ ] Use DEPTH
- [x] Use UV
- [x] Use BORDER
- [x] Try different OPTIMIZERS

## Training Loop

In [None]:
# Main training loop 

def main(
    *,
    data_dir: str = '/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep',
    batch_size: int = 8,
    num_epochs: int = 50,
    learning_rate: float = 1e-4,
    img_size: Tuple[int, int] = (256, 256),
) -> None:
    """
    Main training loop.

    Builds the dataloaders, a ``ResNetUnet`` (resnet34 backbone,
    ``pretrained=True``), a ``UVReconstructionLoss`` criterion and an AdamW
    optimizer, then trains for ``num_epochs`` epochs.

    Side effects:
        * ``training_history.pth`` is rewritten after every epoch with the
          epoch index and the train/validation loss lists.
        * ``best_model.pth`` is rewritten whenever the validation loss improves.
        * Every 5th epoch a validation batch is run through the model and
          visualized.
        * A train/validation loss curve is shown once training finishes.

    All arguments are keyword-only and default to the values that were
    previously hard-coded, so ``main()`` behaves as before.

    Args:
        data_dir: Dataset directory forwarded to ``get_dataloaders``.
        batch_size: Batch size forwarded to ``get_dataloaders``.
        num_epochs: Number of training epochs.
        learning_rate: Learning rate for the AdamW optimizer.
        img_size: Image size forwarded to ``get_dataloaders``; the default
            (256, 256) is the smaller, faster setting, (512, 512) is the
            larger alternative.

    *** CHECK TODO LIST ***
    """

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create dataloaders
    train_loader, val_loader = get_dataloaders(
        data_dir=data_dir,
        batch_size=batch_size,
        img_size=img_size,
        use_depth=False,  # TODO: True if using depth info
        use_uv=True,  # TODO: True if using UV maps
        use_border=True,  # TODO: True if using border masks for better training
    )

    # Create model
    # model = DocumentReconstructionModel().to(device)
    model = ResNetUnet(
        backbone_name='resnet34',
        pretrained=True
    ).to(device)
    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Inspect one training batch and visualize the untrained model's
    # flow/warp output as a baseline.
    sample_batch = next(iter(train_loader))
    print(f"Batch RGB shape: {sample_batch['rgb'].shape}")
    print(f"Batch GT shape: {sample_batch['ground_truth'].shape}")
    if 'border' in sample_batch:
        print(f"Batch border mask shape: {sample_batch['border'].shape}")
    if 'uv' in sample_batch:
        print(f"Batch UV shape: {sample_batch['uv'].shape}")
    # visualize_batch(sample_batch)  # Optional visualization
    visualize_flow_and_warp(model, sample_batch, device, num_samples=4)

    # TODO: TRY DIFFERENT LOSS FUNCTIONS
    # Option 1: simple losses (baseline, just to get it working)
    # criterion = nn.MSELoss() # Basic L2 Loss - sensitive to lighting
    # criterion = nn.L1Loss() # L1 loss also sensitive to lighting

    # Option 2: SSIM loss (RECOMMENDED - structure instead of lighting)
    # criterion = SSIMLoss() # Might need a pip install

    # Option 3: Masked losses (doc pixel focus)
    # Make sure to do use_border = True above **
    # criterion = MaskedL1Loss(use_mask=True)
    # criterion = MaskedMSELoss(use_mask=True)

    # Option 4: Combined loss with UV supervision (GET TO THIS EVENTUALLY)
    # NOTE: need to set use_uv = True, use_border = True,
    criterion = UVReconstructionLoss(
        reconstruction_weight=1.0,
        uv_weight=0.0,  # Turning off for now
        smoothness_weight=0.01,
        use_mask=True,
        loss_type='ssim'  # Use SSIM for geometric recon
    )

    # TODO: TRY DIFFERENT OPTIMIZERS
    # optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    # Training loop
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # Epoch timing check
        start_epoch = time.time()

        print(f"\n{'='*50}")
        print(f"EPOCH {epoch+1}/{num_epochs}")
        print(f"{'='*50}")

        # TRAIN
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
        print(f"Train loss: {train_loss:.4f}")
        train_losses.append(train_loss)

        # VALIDATE
        val_loss = validate(model, val_loader, criterion, device)
        print(f"Val Loss: {val_loss:.4f}")
        val_losses.append(val_loss)

        # Persist the history every epoch so an interrupted run can still be plotted.
        torch.save({
            'epoch': epoch,
            'train_losses': train_losses,
            'val_losses': val_losses
        }, 'training_history.pth')

        # Save the best model BEFORE the optional visualization below, so a
        # plotting failure cannot cost us an improved checkpoint.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"Saved best model with val loss {val_loss:.4f}")

        # Every 5th epoch: run one validation batch and visualize the prediction.
        if (epoch + 1) % 5 == 0:
            model.eval()
            with torch.no_grad():
                sample_batch = next(iter(val_loader))

                rgb = sample_batch['rgb'].to(device)
                outputs = model(rgb, predict_uv=False)

                # # MASK SANITY CHECK
                # rgb_single = sample_batch['rgb'][0].cpu()       # [3, H, W]
                # border_single = sample_batch['border'][0].cpu() # [1, H, W]

                # mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
                # std  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
                # rgb_vis = torch.clamp(rgb_single * std + mean, 0, 1)  # [3, H, W]

                # rgb_np  = rgb_vis.permute(1,2,0).numpy()              # [H, W, 3]
                # mask_np = border_single.squeeze(0).numpy()            # [H, W]

                # plt.figure(figsize=(12,4))
                # plt.subplot(1,3,1); plt.imshow(rgb_np); plt.title("RGB"); plt.axis("off")
                # plt.subplot(1,3,2); plt.imshow(mask_np, cmap="gray"); plt.title("Raw mask"); plt.axis("off")
                # plt.subplot(1,3,3); plt.imshow(rgb_np); plt.imshow(mask_np, cmap="jet", alpha=0.3); plt.title("Overlay"); plt.axis("off")
                # plt.show()

                batch_vis = {
                    'rgb': sample_batch['rgb'],
                    'ground_truth': sample_batch['ground_truth'],
                    'predicted': outputs['warped'].cpu()
                }

                print(f"\n[Visualization] Epoch {epoch+1}")
                visualize_batch_pred(batch_vis, num_samples=4)
                visualize_flow_and_warp(model, sample_batch, device, num_samples=4)

        epoch_time = time.time() - start_epoch
        print(f"Epoch time: {epoch_time:.4f}s")

    print("\nTraining complete!")
    print(f"Best validation loss: {best_val_loss:.4f}")

    # Loss curve: take the x-axis from the recorded history so it always
    # matches the number of points actually collected.
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, val_losses, label='Validation Loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training & Validation Loss Curve')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()
    

In [None]:
# Entry cell: print a banner, smoke-test the data pipeline, then train.

print("Document Reconstruction Dataset Loader", "=" * 50, sep="\n")

# Quick test: build the loaders once with a small batch to confirm the
# dataset can be read before starting the full training run.
train_loader, val_loader = get_dataloaders(
    img_size=(256, 256),  # (512, 512) is the larger alternative
    batch_size=4,
    data_dir='/home/daniel-choate/Datasets/DocUnwarp/renders/synthetic_data_pitch_sweep',
)

print("\nDataset loaded successfully!")

# Optional sample-batch visualization (disabled):
# print("\nVisualizing a sample batch ...")
# sample_batch = next(iter(train_loader))
# print(f"Batch shape - RGB: {sample_batch['rgb'].shape}, Ground Truth: {sample_batch['ground_truth'].shape}")
# visualize_batch(sample_batch, num_samples=min(4, sample_batch['rgb'].shape[0]))

# Launch the main training loop.
main()

In [None]:
# Loss plot: reload the per-epoch history saved in 'training_history.pth'
# (written by main() after every epoch) and plot both curves.
import matplotlib.pyplot as plt

history = torch.load('training_history.pth')

train_losses = history['train_losses']
val_losses = history['val_losses']

# Derive the x-axis from the recorded history rather than a hard-coded epoch
# count: plt.plot raises if x and y lengths differ, and the history length
# depends on how many epochs actually completed.
epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(8, 5))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training & Validation Loss Curve')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()